fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/lora.py ADDED
@@ -0,0 +1,207 @@
+ import argparse
+ import json
+ import logging
+ import os
+
+ import mlx.optimizers as optim
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from .prompt_utils import apply_chat_template
+ from .trainer import Dataset, Trainer, save_adapter
+ from .trainer.utils import apply_lora_layers, find_all_linear_names, get_peft_model
+ from .utils import load, load_image_processor
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def custom_print(*args, **kwargs):
+     tqdm.write(" ".join(map(str, args)), **kwargs)
+
+
+ def main(args):
+     logger.info(f"\033[32mLoading model from {args.model_path}\033[0m")
+     model, processor = load(
+         args.model_path, processor_config={"trust_remote_code": True}
+     )
+     config = model.config.__dict__
+     image_processor = load_image_processor(args.model_path)
+
+     logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")
+     dataset = load_dataset(args.dataset, split=args.split)
+
+     if "messages" not in dataset.column_names:
+         raise ValueError("Dataset must have a 'messages' column")
+     if "images" not in dataset.column_names:
+         raise ValueError("Dataset must have an 'images' column")
+
+     if args.apply_chat_template:
+         logger.info(f"\033[32mApplying chat template to the dataset\033[0m")
+
+         def process_data(examples):
+             if config["model_type"] == "pixtral":
+                 conversations = apply_chat_template(
+                     config=config,
+                     processor=processor,
+                     prompt=examples["messages"],
+                     return_messages=True,
+                 )
+                 examples["messages"] = [
+                     json.dumps(item, ensure_ascii=False) for item in conversations
+                 ]
+             else:
+                 examples["messages"] = apply_chat_template(
+                     config=config,
+                     processor=processor,
+                     prompt=examples["messages"],
+                     return_messages=True,
+                 )
+             return examples
+
+         dataset = dataset.map(process_data)
+
+     dataset = Dataset(
+         dataset,
+         config,
+         processor,
+         image_processor=image_processor,
+         image_resize_shape=args.image_resize_shape,
+     )
+
+     adapter_path = args.adapter_path
+     if adapter_path:
+         logger.info(f"\033[32mResuming from adapter path {adapter_path}\033[0m")
+         logger.info(
+             f"\033[32mLora rank, alpha, and dropout will be loaded from adapter_config.json file\033[0m"
+         )
+
+         model = apply_lora_layers(model, adapter_path)
+
+     else:
+         logger.info(f"\033[32mSetting up LoRA\033[0m")
+
+         list_of_modules = find_all_linear_names(model.language_model)
+         model = get_peft_model(
+             model,
+             list_of_modules,
+             rank=args.lora_rank,
+             alpha=args.lora_alpha,
+             dropout=args.lora_dropout,
+         )
+
+     logger.info(f"\033[32mSetting up optimizer\033[0m")
+     optimizer = optim.Adam(learning_rate=args.learning_rate)
+
+     logger.info(f"\033[32mSetting up trainer\033[0m")
+     trainer = Trainer(model, optimizer)
+
+     model.train()
+
+     # Training loop
+     logger.info(f"\033[32mTraining model\033[0m")
+     for epoch in range(args.epochs):
+         if args.steps == 0:
+             args.steps = len(dataset) // args.batch_size
+
+         progress_bar = tqdm(range(args.steps), position=0, leave=True)
+         for i in progress_bar:
+             loss = trainer.train_step(
+                 dataset[i * args.batch_size : (i + 1) * args.batch_size]
+             )
+             # Update progress bar
+             progress_bar.update(1)
+             progress_bar.set_postfix(
+                 {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"}
+             )
+
+             if i % args.print_every == 0:
+                 # Log additional information
+                 custom_print(
+                     {
+                         "Epoch": epoch,
+                         "Step": i,
+                         "Loss": f"{loss.item():.4f}",
+                     }
+                 )
+         # Save the interim adapter after each epoch except the last.
+         if args.save_after_epoch and (epoch < (args.epochs - 1)):
+             head, tail = os.path.split(args.output_path)
+             save_adapter(model, head + os.sep + "epoch_" + str(epoch) + "_" + tail)
+
+     # Save the adapter
+     save_adapter(model, args.output_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Train NanoLLaVA model")
+     parser.add_argument(
+         "--model-path",
+         type=str,
+         default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
+         help="Path to the pre-trained model",
+     )
+     parser.add_argument(
+         "--dataset", type=str, required=True, help="Path to the dataset"
+     )
+     parser.add_argument(
+         "--split", type=str, default="train", help="Split to use for training"
+     )
+     parser.add_argument(
+         "--image-resize-shape",
+         type=int,
+         nargs=2,
+         default=None,
+         help="Resize images to this shape",
+     )
+     parser.add_argument(
+         "--apply-chat-template",
+         action="store_false",
+         help="Apply chat template to the dataset",
+     )
+     parser.add_argument(
+         "--learning-rate",
+         type=float,
+         default=1e-4,
+         help="Learning rate for the optimizer",
+     )
+     parser.add_argument(
+         "--batch-size", type=int, default=1, help="Batch size for training"
+     )
+     parser.add_argument(
+         "--epochs", type=int, default=1, help="Number of epochs to train"
+     )
+     parser.add_argument(
+         "--steps", type=int, default=0, help="Number of steps per epoch"
+     )
+     parser.add_argument(
+         "--print-every", type=int, default=10, help="Print loss every n steps"
+     )
+     parser.add_argument(
+         "--lora-alpha",
+         type=float,
+         default=0.1,
+         help="LoRA scaling factor (alpha / rank)",
+     )
+     parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank")
+     parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
+     parser.add_argument(
+         "--output-path",
+         type=str,
+         default="adapters",
+         help="Path to save the trained adapter",
+     )
+     parser.add_argument(
+         "--adapter-path",
+         type=str,
+         default=None,
+         help="Load path to resume training from a previously saved adapter",
+     )
+     parser.add_argument(
+         "--save-after-epoch",
+         action="store_true",
+         help="Save interim versions of adapter files after each epoch",
+     )
+
+     args = parser.parse_args()
+     main(args)
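The trainer above is driven entirely through argparse, so the usual entry point is "python -m mlx_vlm.lora" with the flags defined at the bottom of the file. The snippet below is a minimal, illustrative sketch of the equivalent programmatic call; it is not part of the wheel, the dataset name is a placeholder, and the model path is simply the script's default. Note that "--apply-chat-template" is declared with action="store_false", so the template is applied by default and the flag turns it off.

# Illustrative sketch only (not shipped in the wheel): drive mlx_vlm.lora
# programmatically instead of via the CLI. The dataset name is a placeholder;
# any Hugging Face dataset with "messages" and "images" columns satisfies the
# checks in main().
from argparse import Namespace

from mlx_vlm.lora import main

args = Namespace(
    model_path="mlx-community/Qwen2-VL-2B-Instruct-bf16",  # script default
    dataset="your-org/your-vlm-sft-dataset",  # placeholder
    split="train",
    image_resize_shape=None,
    apply_chat_template=True,  # CLI flag uses store_false, so True is the default
    learning_rate=1e-4,
    batch_size=1,
    epochs=1,
    steps=0,  # 0 -> len(dataset) // batch_size
    print_every=10,
    lora_alpha=0.1,
    lora_rank=10,
    lora_dropout=0.1,
    output_path="adapters",
    adapter_path=None,
    save_after_epoch=False,
)
main(args)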
mlx_vlm/models/aya_vision/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .aya_vision import LanguageModel, Model, VisionModel
+ from .config import ModelConfig, TextConfig, VisionConfig
mlx_vlm/models/aya_vision/aya_vision.py ADDED
@@ -0,0 +1,188 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class AyaVisionMultiModalProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.downsample_factor = config.downsample_factor
+         self.alignment_intermediate_size = getattr(
+             config, "alignment_intermediate_size", config.text_config.hidden_size
+         )
+         if config.model_type == "aya_vision":
+             self.layernorm = nn.LayerNorm(
+                 config.vision_config.hidden_size * (config.downsample_factor**2),
+                 eps=config.adapter_layer_norm_eps,
+             )
+
+         self.linear_1 = nn.Linear(
+             config.vision_config.hidden_size * (config.downsample_factor**2),
+             self.alignment_intermediate_size,
+             bias=True,
+         )
+
+         self.act = nn.SiLU()  # SwiGLU uses SiLU activation
+
+         # For SwiGLU, project down to half size since we split intermediate dim
+         self.linear_2 = nn.Linear(
+             self.alignment_intermediate_size // 2,
+             config.text_config.hidden_size,
+             bias=True,
+         )
+
+     def __call__(self, image_features):
+         image_features = self.pixel_shuffle(image_features)
+         if self.config.model_type == "aya_vision":
+             image_features = self.layernorm(image_features)
+         hidden_states = self.linear_1(image_features)
+
+         # Split along last dimension and apply SwiGLU
+         x, gate = mx.split(hidden_states, 2, axis=-1)
+         hidden_states = self.act(gate) * x
+
+         hidden_states = self.linear_2(hidden_states)
+         return hidden_states
+
+     def pixel_shuffle(self, image_features):  # B, S, D
+         batch_size, seq_length, feature_dim = image_features.shape
+         height = width = int(seq_length**0.5)
+         image_features = image_features.reshape(
+             image_features.shape[0], width, height, -1
+         )
+         channels = image_features.shape[-1]
+         image_features = image_features.reshape(
+             batch_size,
+             width,
+             int(height / self.downsample_factor),
+             int(channels * self.downsample_factor),
+         )
+         image_features = image_features.transpose(0, 2, 1, 3)
+         image_features = image_features.reshape(
+             batch_size,
+             int(height / self.downsample_factor),
+             int(width / self.downsample_factor),
+             -1,
+         )
+         image_features = image_features.transpose(0, 2, 1, 3)
+         return image_features
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.multi_modal_projector = AyaVisionMultiModalProjector(config)
+         self.vision_feature_layer = config.vision_feature_layer
+         self.vision_feature_select_strategy = config.vision_feature_select_strategy
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         # Get the input embeddings from the language model
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         spatial_shapes = kwargs.get("spatial_shapes", None)
+         # Get the output hidden states from the vision model
+         *_, hidden_states = self.vision_tower(
+             pixel_values.transpose(0, 2, 3, 1),
+             spatial_shapes=spatial_shapes,
+             output_hidden_states=True,
+         )
+
+         # Select the hidden states from the desired layer
+         selected_image_feature = hidden_states[self.vision_feature_layer]
+
+         if self.vision_feature_select_strategy == "default":
+             selected_image_feature = selected_image_feature[:, 1:]
+         elif self.vision_feature_select_strategy == "full":
+             selected_image_feature = selected_image_feature
+         else:
+             raise ValueError(
+                 "Unexpected feature selection strategy: "
+                 f"{self.vision_feature_select_strategy}"
+             )
+
+         # Pass image features through the multi-modal projector
+         image_features = self.multi_modal_projector(selected_image_feature)
+
+         # Insert special image tokens in the input_ids
+         final_inputs_embeds = self._merge_input_ids_with_image_features(
+             image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _merge_input_ids_with_image_features(
+         self, image_features, inputs_embeds, input_ids
+     ):
+         image_token_index = self.config.image_token_index
+
+         # Positions of <image> tokens in input_ids, assuming batch size is 1
+         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+         num_images, _, _, vision_hidden_size = image_features.shape
+
+         reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)
+
+         # cast to the dtype of the input_embeds to support quantized models
+         reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+             inputs_embeds.dtype
+         )
+         inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+         return inputs_embeds
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+         logits = self.language_model(
+             input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         def transform_key(key):
+             if "model.vision_tower" in key:
+                 key = key.replace("model.vision_tower", "vision_tower")
+             if "model.multi_modal_projector" in key:
+                 key = key.replace(
+                     "model.multi_modal_projector", "multi_modal_projector"
+                 )
+             if "model.language_model" in key:
+                 key = key.replace("model.language_model", "language_model.model")
+             if "lm_head" in key and not key.startswith("language_model"):
+                 key = key.replace("lm_head", "language_model.lm_head")
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
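For reference, pixel_shuffle trades spatial resolution for channel depth before projection: with downsample_factor f, a (B, S, D) grid of vision tokens where S = h * w becomes (B, h/f, w/f, D * f * f), which is why linear_1 takes vision hidden_size * downsample_factor**2 as its input width. A standalone shape sketch with illustrative sizes (not part of the wheel) follows; it mirrors the reshape/transpose sequence above.

# Illustrative sketch only (not shipped in the wheel): shape check for the
# pixel-shuffle downsampling above. Sizes are arbitrary examples: 1 image,
# 24x24 = 576 vision tokens, hidden size 1024, downsample_factor 2.
import mlx.core as mx

B, h, w, D, f = 1, 24, 24, 1024, 2
tokens = mx.zeros((B, h * w, D))          # (B, S, D) from the vision tower

x = tokens.reshape(B, w, h, -1)           # (1, 24, 24, 1024)
x = x.reshape(B, w, h // f, D * f)        # (1, 24, 12, 2048)
x = x.transpose(0, 2, 1, 3)               # (1, 12, 24, 2048)
x = x.reshape(B, h // f, w // f, -1)      # (1, 12, 12, 4096)
x = x.transpose(0, 2, 1, 3)               # (1, 12, 12, 4096)

print(x.shape)  # (1, 12, 12, 4096) == (B, h/f, w/f, D * f * f)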
mlx_vlm/models/aya_vision/config.py ADDED
@@ -0,0 +1,52 @@
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int = 8192
+     head_dim: int = 128
+     num_hidden_layers: int = 40
+     intermediate_size: int = 14336
+     num_attention_heads: int = 64
+     num_key_value_heads: int = 8
+     rope_theta: float = 50000.0
+     vocab_size: int = 256000
+     layer_norm_eps: float = 1e-05
+     logit_scale: float = 0.0625
+     attention_bias: bool = False
+     layer_norm_bias: bool = False
+     sliding_window: int = 4096
+     sliding_window_pattern: int = 4
+     max_position_embeddings: int = 4096
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     num_attention_heads: int
+     patch_size: int
+     num_hidden_layers: int = 12
+     intermediate_size: int = 3072
+     image_size: int = 224
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     image_token_index: int = 255036
+     max_splits_per_img: int = 12
+     downsample_factor: int = 2
+     alignment_intermediate_size: int = 28672
+     adapter_layer_norm_eps: float = 1e-06
+     vision_feature_layer: int = -1
+     vision_feature_select_strategy: str = "full"
+     eos_token_id: Optional[List[int]] = None
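Because the config module is a set of plain dataclasses, a configuration can be constructed directly in Python as well as populated from a checkpoint's config.json. The sketch below is illustrative only and not part of the wheel; the vision dimensions and the model_type strings are assumptions, and it presumes BaseModelConfig adds no required fields of its own.

# Illustrative sketch only (not shipped in the wheel): build an aya_vision
# ModelConfig by hand. The model_type strings and vision dimensions are
# assumed example values; real ones come from the checkpoint's config.json.
from mlx_vlm.models.aya_vision.config import ModelConfig, TextConfig, VisionConfig

config = ModelConfig(
    model_type="aya_vision",
    text_config=TextConfig(model_type="cohere2"),  # assumed; defaults give 8192 hidden, 40 layers
    vision_config=VisionConfig(
        model_type="siglip_vision_model",  # assumed example value
        hidden_size=1024,
        num_attention_heads=16,
        patch_size=14,
    ),
)

print(config.image_token_index, config.downsample_factor)  # 255036 2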
mlx_vlm/models/aya_vision/language.py ADDED
@@ -0,0 +1,202 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache, RotatingKVCache
+ from .config import TextConfig
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+         self.head_dim = head_dim = config.head_dim
+         if (head_dim * n_heads) != dim:
+             raise ValueError(
+                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
+                 f" and `num_heads`: {n_heads})."
+             )
+         self.scale = head_dim**-0.5
+
+         attention_bias = config.attention_bias
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+         self.rope = nn.RoPE(head_dim, traditional=True, base=config.rope_theta)
+
+         self.use_sliding_window = (layer_idx + 1) % config.sliding_window_pattern != 0
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         # Apply RoPE only if sliding window is enabled
+         if self.use_sliding_window:
+             if cache is None:
+                 queries = self.rope(queries)
+                 keys = self.rope(keys)
+             else:
+                 queries = self.rope(queries, offset=cache.offset)
+                 keys = self.rope(keys, offset=cache.offset)
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         if self.use_sliding_window and mask is not None and isinstance(mask, mx.array):
+             key_len = keys.shape[-2]
+             if mask.shape[-1] != key_len:
+                 mask = mask[..., -key_len:]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+     def __call__(self, x):
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.n_heads = config.num_attention_heads
+
+         self.self_attn = Attention(config, layer_idx)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps, bias=config.layer_norm_bias
+         )
+         self.config = config
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         h = self.input_layernorm(x)
+         attn_h = self.self_attn(h, mask, cache)
+         ff_h = self.mlp(h)
+         return attn_h + ff_h + x
+
+
+ class CohereModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config, layer_idx=i)
+             for i in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps, bias=config.layer_norm_bias
+         )
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: mx.array = None,
+         mask: mx.array = None,
+         cache=None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             j = self.config.sliding_window_pattern
+             mask = create_attention_mask(h, cache[j - 1 : j])
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.model = CohereModel(config)
+         self.config = config
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: mx.array = None,
+         mask: mx.array = None,
+         cache=None,
+     ):
+         out = self.model(inputs, inputs_embeds, mask, cache)
+         out = self.model.embed_tokens.as_linear(out)
+         out = out * self.model.config.logit_scale
+         return LanguageModelOutput(logits=out)
+
+     def make_cache(self):
+         caches = []
+         for i in range(self.config.num_hidden_layers):
+             if (
+                 i % self.config.sliding_window_pattern
+                 == self.config.sliding_window_pattern - 1
+             ):
+                 caches.append(KVCache())
+             else:
+                 caches.append(
+                     RotatingKVCache(max_size=self.config.sliding_window, keep=0)
+                 )
+         return caches
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.model.config.head_dim
+
+     @property
+     def n_kv_heads(self):
+         return self.model.config.num_key_value_heads
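make_cache mirrors the interleaved attention layout visible in Attention.__init__: layers where use_sliding_window is False (every sliding_window_pattern-th layer) skip RoPE and get a full KVCache, while the remaining layers apply RoPE and get a RotatingKVCache bounded by sliding_window. The snippet below is an illustrative sketch of the resulting layout for the TextConfig defaults, with strings standing in for the cache objects; it is not part of the wheel.

# Illustrative sketch only (not shipped in the wheel): cache layout produced by
# LanguageModel.make_cache for the TextConfig defaults
# (40 layers, sliding_window_pattern=4, sliding_window=4096).
num_hidden_layers, pattern = 40, 4

layout = [
    "KVCache (global, no RoPE)" if i % pattern == pattern - 1 else "RotatingKVCache(4096)"
    for i in range(num_hidden_layers)
]

print(layout[:8])
# ['RotatingKVCache(4096)', 'RotatingKVCache(4096)', 'RotatingKVCache(4096)',
#  'KVCache (global, no RoPE)', ...]  -> layers 3, 7, 11, ... attend globally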