rapidfireai 0.9.9__py3-none-any.whl → 0.9.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,371 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### RapidFire AI Tutorial Use Case: GRPO for Math Reasoning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from rapidfireai import Experiment\n",
+ "from rapidfireai.automl import List, RFGridSearch, RFModelConfig, RFLoraConfig, RFGRPOConfig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Dataset and Specify Train and Eval Partitions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset, Dataset\n",
+ "\n",
+ "def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
+ "    data = load_dataset('openai/gsm8k', 'main')[split]\n",
+ "    return data\n",
+ "\n",
+ "# Select a subset of the dataset for demo purposes\n",
+ "train_dataset = get_gsm8k_questions(split=\"train\").select(range(500))\n",
+ "eval_dataset = get_gsm8k_questions(split=\"test\").select(range(100))\n",
+ "train_dataset = train_dataset.shuffle(seed=42)\n",
+ "eval_dataset = eval_dataset.shuffle(seed=42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Define Data Processing Function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sample_formatting_function(row):\n",
+ "    \"\"\"Function to preprocess each example from dataset\"\"\"\n",
+ "\n",
+ "    def extract_hash_answer(text: str) -> str | None:\n",
+ "        if \"####\" not in text:\n",
+ "            return None\n",
+ "        answer = text.split(\"####\")[1].strip()\n",
+ "        try:\n",
+ "            answer = answer.replace(\",\", \"\")\n",
+ "        except:\n",
+ "            return None\n",
+ "        return answer\n",
+ "\n",
+ "    SYSTEM_PROMPT = \"\"\"\n",
+ "    Respond in the following format:\n",
+ "    <reasoning>\n",
+ "    ...\n",
+ "    </reasoning>\n",
+ "    <answer>\n",
+ "    ...\n",
+ "    </answer>\n",
+ "    \"\"\"\n",
+ "    return { # Return a conversation format dictionary\n",
+ "        'prompt': [\n",
+ "            {'role': 'system', 'content': SYSTEM_PROMPT},\n",
+ "            {'role': 'user', 'content': row['question']}\n",
+ "        ],\n",
+ "        'question': row['question'],\n",
+ "        'answer': extract_hash_answer(row['answer'])\n",
+ "    }"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialize Experiment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Every experiment instance must be uniquely named\n",
+ "experiment = Experiment(experiment_name=\"exp1-math-reasoning\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Define Custom Reward Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
+ "\n",
+ "    def extract_xml_answer(text: str) -> str:\n",
+ "        answer = text.split(\"<answer>\")[-1]\n",
+ "        answer = answer.split(\"</answer>\")[0]\n",
+ "        return answer.strip()\n",
+ "\n",
+ "    responses = [completion[0]['content'] for completion in completions]\n",
+ "    q = prompts[0][-1]['content']\n",
+ "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
+ "    # x('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
+ "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
+ "\n",
+ "def int_reward_func(completions, **kwargs) -> list[float]:\n",
+ "\n",
+ "    def extract_xml_answer(text: str) -> str:\n",
+ "        answer = text.split(\"<answer>\")[-1]\n",
+ "        answer = answer.split(\"</answer>\")[0]\n",
+ "        return answer.strip()\n",
+ "    responses = [completion[0]['content'] for completion in completions]\n",
+ "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
+ "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
+ "\n",
+ "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
+ "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
+ "    import re\n",
+ "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
+ "    responses = [completion[0][\"content\"] for completion in completions]\n",
+ "    matches = [re.match(pattern, r) for r in responses]\n",
+ "    return [0.5 if match else 0.0 for match in matches]\n",
+ "\n",
+ "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
+ "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
+ "    import re\n",
+ "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
+ "    responses = [completion[0][\"content\"] for completion in completions]\n",
+ "    matches = [re.match(pattern, r) for r in responses]\n",
+ "    return [0.5 if match else 0.0 for match in matches]\n",
+ "\n",
+ "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
+ "    def count_xml(text) -> float:\n",
+ "        count = 0.0\n",
+ "        if text.count(\"<reasoning>\\n\") == 1:\n",
+ "            count += 0.125\n",
+ "        if text.count(\"\\n</reasoning>\\n\") == 1:\n",
+ "            count += 0.125\n",
+ "        if text.count(\"\\n<answer>\\n\") == 1:\n",
+ "            count += 0.125\n",
+ "            count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
+ "        if text.count(\"\\n</answer>\") == 1:\n",
+ "            count += 0.125\n",
+ "            count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
+ "        return count\n",
+ "    contents = [completion[0][\"content\"] for completion in completions]\n",
+ "    return [count_xml(c) for c in contents]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Define Multi-Config Knobs for Model, LoRA, and GRPO Trainer using RapidFire AI Wrapper APIs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lora_config = RFLoraConfig(\n",
+ "    r=32,\n",
+ "    lora_alpha=64,\n",
+ "    lora_dropout=0.05,\n",
+ "    target_modules=[\"q_proj\", \"v_proj\"],\n",
+ "    bias=\"none\"\n",
+ ")\n",
+ "\n",
+ "grpo_config1 = RFGRPOConfig(\n",
+ "    learning_rate=5e-6,\n",
+ "    warmup_ratio=0.1,\n",
+ "    weight_decay=0.1,\n",
+ "    max_grad_norm=0.1,\n",
+ "    adam_beta1=0.9,\n",
+ "    adam_beta2=0.99,\n",
+ "    lr_scheduler_type=\"linear\",\n",
+ "    per_device_train_batch_size=4,\n",
+ "    gradient_accumulation_steps=4,\n",
+ "    num_generations=8,\n",
+ "    optim=\"adamw_8bit\",\n",
+ "    num_train_epochs=2,\n",
+ "    max_prompt_length=1024,\n",
+ "    max_completion_length=1024,\n",
+ "    logging_steps=2,\n",
+ "    eval_steps=5\n",
+ ")\n",
+ "\n",
+ "grpo_config2 = grpo_config1.copy()\n",
+ "grpo_config2.learning_rate = 1e-5\n",
+ "\n",
+ "reward_funcs = [\n",
+ "    correctness_reward_func,\n",
+ "    int_reward_func,\n",
+ "    strict_format_reward_func,\n",
+ "    soft_format_reward_func,\n",
+ "    xmlcount_reward_func,\n",
+ "]\n",
+ "\n",
+ "# List of 4 separate configs\n",
+ "config_set = List([\n",
+ "    RFModelConfig(\n",
+ "        model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
+ "        peft_config=lora_config,\n",
+ "        training_args=grpo_config1,\n",
+ "        formatting_func=sample_formatting_function,\n",
+ "        reward_funcs=reward_funcs,\n",
+ "        model_kwargs={\"load_in_4bit\": True, \"device_map\": \"auto\", \"torch_dtype\": \"auto\", \"use_cache\": False},\n",
+ "        tokenizer_kwargs={\"model_max_length\": 2048, \"padding_side\": \"left\", \"truncation\": True}\n",
+ "    ),\n",
+ "    RFModelConfig(\n",
+ "        model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n",
+ "        peft_config=lora_config,\n",
+ "        training_args=grpo_config1,\n",
+ "        formatting_func=sample_formatting_function,\n",
+ "        reward_funcs=reward_funcs,\n",
+ "        model_kwargs={\"load_in_4bit\": True, \"device_map\": \"auto\", \"torch_dtype\": \"auto\", \"use_cache\": False},\n",
+ "        tokenizer_kwargs={\"model_max_length\": 2048, \"padding_side\": \"left\", \"truncation\": True}\n",
+ "    ),\n",
+ "    RFModelConfig(\n",
+ "        model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n",
+ "        peft_config=lora_config,\n",
+ "        training_args=grpo_config2,\n",
+ "        formatting_func=sample_formatting_function,\n",
+ "        reward_funcs=reward_funcs,\n",
+ "        model_kwargs={\"load_in_4bit\": True, \"device_map\": \"auto\", \"torch_dtype\": \"auto\", \"use_cache\": False},\n",
+ "        tokenizer_kwargs={\"model_max_length\": 2048, \"padding_side\": \"left\", \"truncation\": True}\n",
+ "    ),\n",
+ "    RFModelConfig(\n",
+ "        model_name=\"Qwen/Qwen2.5-7B-Instruct\",\n",
+ "        peft_config=lora_config,\n",
+ "        training_args=grpo_config1,\n",
+ "        formatting_func=sample_formatting_function,\n",
+ "        reward_funcs=reward_funcs,\n",
+ "        model_kwargs={\"load_in_4bit\": True, \"device_map\": \"auto\", \"torch_dtype\": \"auto\", \"use_cache\": False},\n",
+ "        tokenizer_kwargs={\"model_max_length\": 2048, \"padding_side\": \"left\", \"truncation\": True}\n",
+ "    ),\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Define Model Creation Function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sample_create_model(model_config):\n",
+ "    \"\"\"Function to create model object for any given config; must return tuple of (model, tokenizer)\"\"\"\n",
+ "    from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "\n",
+ "    model_name = model_config[\"model_name\"]\n",
+ "    model_kwargs = model_config[\"model_kwargs\"]\n",
+ "    tokenizer_kwargs = model_config[\"tokenizer_kwargs\"]\n",
+ "    return (\n",
+ "        AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs),\n",
+ "        AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)\n",
+ "    )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Generate Config Group"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple grid search across all sets of config knob values = 4 combinations in total\n",
+ "config_group = RFGridSearch(\n",
+ "    configs=config_set,\n",
+ "    trainer_type=\"GRPO\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Run Multi-Config Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Launch training of all configs in the config_group with swap granularity of 4 chunks\n",
+ "experiment.run_fit(config_group, sample_create_model, train_dataset, eval_dataset, num_chunks=4, seed=42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### End Current Experiment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "experiment.end()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "oss_venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
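
For readers scanning the diff, the sketch below shows how the pieces of the new tutorial notebook fit together at the data level. It is illustrative only and not part of the packaged file: it assumes `sample_formatting_function` and the five reward functions are defined exactly as in the cells above, and the GSM8K-style row and model completion are made up for the example.

```python
# Illustrative sketch; assumes the notebook's sample_formatting_function and reward
# functions (correctness_reward_func, int_reward_func, strict_format_reward_func,
# soft_format_reward_func, xmlcount_reward_func) are already defined in the session.

# A made-up GSM8K-style row (real rows come from load_dataset('openai/gsm8k', 'main')).
toy_row = {
    "question": "Tom has 3 apples and buys 4 more. How many apples does he have?",
    "answer": "Tom has 3 + 4 = 7 apples.\n#### 7",
}

example = sample_formatting_function(toy_row)
print(example["answer"])     # '7' -- extracted from the text after the '####' marker
print(example["prompt"][0])  # system message with the <reasoning>/<answer> format instructions
print(example["prompt"][1])  # user message carrying the raw question

# A made-up, well-formed completion in the format the system prompt requests.
toy_completion = [{
    "role": "assistant",
    "content": "<reasoning>\n3 + 4 = 7\n</reasoning>\n<answer>\n7\n</answer>\n",
}]

prompts = [example["prompt"]]
completions = [toy_completion]
answers = [example["answer"]]

print(correctness_reward_func(prompts, completions, answers))  # [2.0]: extracted answer equals '7'
print(int_reward_func(completions))                            # [0.5]: extracted answer is all digits
print(strict_format_reward_func(completions))                  # [0.5]: newline-delimited tag layout matches
print(soft_format_reward_func(completions))                    # [0.0]: '.' in the pattern has no re.DOTALL,
                                                               #        so it cannot span the newlines in the tags
print(xmlcount_reward_func(completions))                       # [0.5]: each of the four tag checks adds 0.125
```

This also illustrates why the notebook stacks several reward functions: the strict check rewards the newline-delimited layout the system prompt asks for, the soft check only rewards single-line tag content, the counting reward gives partial credit per tag, and the 2.0 correctness reward dominates whenever the extracted final answer is right.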