fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/lora.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import mlx.optimizers as optim
|
|
7
|
+
from datasets import load_dataset
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
from .prompt_utils import apply_chat_template
|
|
11
|
+
from .trainer import Dataset, Trainer, save_adapter
|
|
12
|
+
from .trainer.utils import apply_lora_layers, find_all_linear_names, get_peft_model
|
|
13
|
+
from .utils import load, load_image_processor
|
|
14
|
+
|
|
15
|
+
# Module-level logging configuration so the training script shows INFO logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def custom_print(*args, **kwargs):
    """Write ``args`` as one space-separated line through ``tqdm.write``.

    Going through tqdm (instead of ``print``) avoids clobbering an active
    progress bar. Extra keyword arguments are passed to ``tqdm.write`` as-is.
    """
    message = " ".join(str(arg) for arg in args)
    tqdm.write(message, **kwargs)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def main(args):
    """Fine-tune a vision-language model with LoRA and save the adapter.

    Loads the model, processor and image processor from ``args.model_path``,
    loads ``args.split`` of ``args.dataset`` via ``datasets.load_dataset``,
    optionally applies the chat template, wraps the model with LoRA layers
    (fresh ones, or ones restored from ``args.adapter_path``), trains with
    Adam for ``args.epochs`` epochs, and saves the adapter to
    ``args.output_path``.

    Args:
        args: Namespace produced by this module's argument parser. When
            ``args.steps`` is 0 it is overwritten in place with the number of
            full batches in the dataset.

    Raises:
        ValueError: If the dataset lacks a ``messages`` or ``images`` column,
            or if training is requested but the resolved number of steps per
            epoch is not positive (e.g. the dataset is smaller than one batch).
    """
    logger.info(f"\033[32mLoading model from {args.model_path}\033[0m")
    model, processor = load(
        args.model_path, processor_config={"trust_remote_code": True}
    )
    config = model.config.__dict__
    image_processor = load_image_processor(args.model_path)

    logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")
    dataset = load_dataset(args.dataset, split=args.split)

    if "messages" not in dataset.column_names:
        raise ValueError("Dataset must have a 'messages' column")
    if "images" not in dataset.column_names:
        raise ValueError("Dataset must have an 'images' column")

    if args.apply_chat_template:
        logger.info("\033[32mApplying chat template to the dataset\033[0m")

        def process_data(examples):
            conversations = apply_chat_template(
                config=config,
                processor=processor,
                prompt=examples["messages"],
                return_messages=True,
            )
            if config["model_type"] == "pixtral":
                # Pixtral conversations are stored back as JSON strings.
                conversations = [
                    json.dumps(item, ensure_ascii=False) for item in conversations
                ]
            examples["messages"] = conversations
            return examples

        dataset = dataset.map(process_data)

    dataset = Dataset(
        dataset,
        config,
        processor,
        image_processor=image_processor,
        image_resize_shape=args.image_resize_shape,
    )

    # Resolve steps per epoch once, up front. The default (0) means one pass
    # over the data; floor division drops a trailing partial batch.
    if args.steps == 0:
        args.steps = len(dataset) // args.batch_size
    if args.epochs > 0 and args.steps <= 0:
        # Without this check the loop would run zero steps and an untrained
        # adapter would be saved silently.
        raise ValueError(
            f"Nothing to train: {args.steps} steps per epoch "
            f"(dataset size {len(dataset)}, batch size {args.batch_size})."
        )

    adapter_path = args.adapter_path
    if adapter_path:
        logger.info(f"\033[32mResuming from adapter path {adapter_path}\033[0m")
        logger.info(
            "\033[32mLora rank, alpha, and dropout will be loaded from adapter_config.json file\033[0m"
        )
        model = apply_lora_layers(model, adapter_path)
    else:
        logger.info("\033[32mSetting up LoRA\033[0m")
        list_of_modules = find_all_linear_names(model.language_model)
        model = get_peft_model(
            model,
            list_of_modules,
            rank=args.lora_rank,
            alpha=args.lora_alpha,
            dropout=args.lora_dropout,
        )

    logger.info("\033[32mSetting up optimizer\033[0m")
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    logger.info("\033[32mSetting up trainer\033[0m")
    trainer = Trainer(model, optimizer)

    model.train()

    logger.info("\033[32mTraining model\033[0m")
    for epoch in range(args.epochs):
        # Iterating the tqdm object already advances it; an extra
        # progress_bar.update(1) per step would double-count progress.
        progress_bar = tqdm(range(args.steps), position=0, leave=True)
        for i in progress_bar:
            start = i * args.batch_size
            loss = trainer.train_step(dataset[start : start + args.batch_size])

            stats = {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"}
            progress_bar.set_postfix(stats)

            # print_every == 0 disables periodic printing instead of raising
            # ZeroDivisionError.
            if args.print_every and i % args.print_every == 0:
                custom_print(stats)

        # Save an interim adapter after every epoch except the last one,
        # which is covered by the final save below.
        if args.save_after_epoch and epoch < args.epochs - 1:
            head, tail = os.path.split(args.output_path)
            # os.path.join keeps a bare filename (empty head) relative to the
            # CWD; prefixing os.sep would point it at the filesystem root.
            save_adapter(model, os.path.join(head, f"epoch_{epoch}_{tail}"))

    save_adapter(model, args.output_path)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a vision-language model with LoRA adapters"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
        help="Path to the pre-trained model",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--split", type=str, default="train", help="Split to use for training"
    )
    parser.add_argument(
        "--image-resize-shape",
        type=int,
        nargs=2,
        default=None,
        help="Resize images to this shape",
    )
    # The chat template is applied by default. --apply-chat-template enables it
    # explicitly and --no-apply-chat-template turns it off; both write the same
    # destination, so the last flag given wins.
    parser.add_argument(
        "--apply-chat-template",
        dest="apply_chat_template",
        action="store_true",
        default=True,
        help="Apply chat template to the dataset (default: enabled)",
    )
    parser.add_argument(
        "--no-apply-chat-template",
        dest="apply_chat_template",
        action="store_false",
        help="Do not apply the chat template to the dataset",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer",
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="Batch size for training"
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train"
    )
    parser.add_argument(
        "--steps", type=int, default=0, help="Number of steps per epoch"
    )
    parser.add_argument(
        "--print-every", type=int, default=10, help="Print loss every n steps"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=0.1,
        help="LoRA scaling factor (alpha / rank)",
    )
    parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank")
    parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
    parser.add_argument(
        "--output-path",
        type=str,
        default="adapters",
        help="Path to save the trained adapter",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        default=None,
        help="Load path to resume training from a previously saved adapter",
    )
    parser.add_argument(
        "--save-after-epoch",
        action="store_true",
        help="Save interim versions of adapter files after each epoch",
    )

    args = parser.parse_args()
    main(args)
|
|
File without changes
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ..base import InputEmbeddingsFeatures
|
|
8
|
+
from .config import ModelConfig
|
|
9
|
+
from .language import LanguageModel
|
|
10
|
+
from .vision import VisionModel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AyaVisionMultiModalProjector(nn.Module):
    """Map vision-tower features into the language model's hidden space.

    Pipeline: ``pixel_shuffle`` folds spatial neighbours into channels by
    ``downsample_factor``. A LayerNorm follows, only when
    ``model_type == "aya_vision"``. Then ``linear_1`` widens to
    ``alignment_intermediate_size``, a SwiGLU gate halves that width, and
    ``linear_2`` projects to ``text_config.hidden_size``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        factor = config.downsample_factor
        self.downsample_factor = factor
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", config.text_config.hidden_size
        )

        # Channel width after pixel shuffle: factor x factor patches merged.
        shuffled_dim = config.vision_config.hidden_size * (factor**2)

        if config.model_type == "aya_vision":
            self.layernorm = nn.LayerNorm(
                shuffled_dim,
                eps=config.adapter_layer_norm_eps,
            )

        self.linear_1 = nn.Linear(
            shuffled_dim, self.alignment_intermediate_size, bias=True
        )
        # SiLU is the activation used inside the SwiGLU gate.
        self.act = nn.SiLU()
        # The SwiGLU split halves the width, so linear_2 consumes half of it.
        self.linear_2 = nn.Linear(
            self.alignment_intermediate_size // 2,
            config.text_config.hidden_size,
            bias=True,
        )

    def __call__(self, image_features):
        features = self.pixel_shuffle(image_features)
        if self.config.model_type == "aya_vision":
            features = self.layernorm(features)
        widened = self.linear_1(features)

        # SwiGLU: first half of the last axis is the value, second the gate.
        value, gate = mx.split(widened, 2, axis=-1)
        return self.linear_2(self.act(gate) * value)

    def pixel_shuffle(self, image_features):  # B, S, D
        """Fold ``downsample_factor`` x ``downsample_factor`` patch groups into channels.

        Assumes ``S`` is a perfect square (a square patch grid). The result has
        shape ``(B, side / f, side / f, D * f * f)`` where ``f`` is the factor.
        """
        batch_size, seq_length, _ = image_features.shape
        side = int(seq_length**0.5)
        factor = self.downsample_factor
        reduced = int(side / factor)

        grid = image_features.reshape(batch_size, side, side, -1)
        channels = grid.shape[-1]
        # Merge groups of `factor` adjacent positions along one spatial axis...
        grid = grid.reshape(batch_size, side, reduced, int(channels * factor))
        grid = grid.transpose(0, 2, 1, 3)
        # ...then along the other spatial axis.
        grid = grid.reshape(batch_size, reduced, reduced, -1)
        return grid.transpose(0, 2, 1, 3)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class Model(nn.Module):
    """Aya Vision: vision tower -> multimodal projector -> language model.

    Projected image features overwrite the embeddings at positions holding
    ``config.image_token_index`` before the sequence reaches the language model.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config)
        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and, if images are given, splice in their features.

        Args:
            input_ids: Token ids. Image placeholders are looked up in
                ``input_ids[0]`` only.
            pixel_values: Image batch, transposed with ``(0, 2, 3, 1)`` before
                the vision tower (presumably channels-first to channels-last —
                confirm against the processor).
            **kwargs: ``spatial_shapes`` is forwarded to the vision tower.

        Returns:
            ``InputEmbeddingsFeatures`` with the (possibly merged) embeddings.

        Raises:
            ValueError: If ``vision_feature_select_strategy`` is neither
                ``"default"`` nor ``"full"``.
        """
        token_embeds = self.language_model.model.embed_tokens(input_ids)
        if pixel_values is None:
            return InputEmbeddingsFeatures(inputs_embeds=token_embeds)

        # The last element returned by the tower is indexed per layer below.
        *_, layer_outputs = self.vision_tower(
            pixel_values.transpose(0, 2, 3, 1),
            spatial_shapes=kwargs.get("spatial_shapes", None),
            output_hidden_states=True,
        )
        image_hidden = layer_outputs[self.vision_feature_layer]

        strategy = self.vision_feature_select_strategy
        if strategy == "default":
            # Drop the first token of every sequence (conventionally a CLS token).
            image_hidden = image_hidden[:, 1:]
        elif strategy != "full":
            raise ValueError(
                "Unexpected feature selection strategy: "
                f"{self.vision_feature_select_strategy}"
            )

        projected = self.multi_modal_projector(image_hidden)
        merged = self._merge_input_ids_with_image_features(
            projected, token_embeds, input_ids
        )
        return InputEmbeddingsFeatures(inputs_embeds=merged)

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        """Overwrite image-placeholder embeddings with projected image features.

        Only the first row of ``input_ids`` is scanned, so a batch size of 1 is
        assumed. ``inputs_embeds`` is written in place and returned.
        """
        placeholder_id = self.config.image_token_index
        slots = np.where(input_ids[0] == placeholder_id)[0].tolist()
        # Unpacking four dims (rather than shape[-1]) keeps the 4-D requirement.
        _, _, _, feature_dim = image_features.shape

        # Match the embedding dtype so quantized models are supported.
        flat_features = image_features.reshape(-1, feature_dim).astype(
            inputs_embeds.dtype
        )
        inputs_embeds[:, slots, :] = flat_features
        return inputs_embeds

    @property
    def layers(self):
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: mx.array,
        mask: mx.array,
        cache=None,
        **kwargs,
    ):
        # ``mask`` is accepted for interface compatibility but not used here.
        embeddings = self.get_input_embeddings(input_ids, pixel_values, **kwargs)
        return self.language_model(
            input_ids,
            cache=cache,
            inputs_embeds=embeddings.inputs_embeds,
        )

    def sanitize(self, weights):
        """Rename checkpoint keys to this module's attribute layout."""
        # Applied in order; each rule sees the key as rewritten by earlier ones.
        renames = (
            ("model.vision_tower", "vision_tower"),
            ("model.multi_modal_projector", "multi_modal_projector"),
            ("model.language_model", "language_model.model"),
        )

        def rename(key):
            for old, new in renames:
                if old in key:
                    key = key.replace(old, new)
            if "lm_head" in key and not key.startswith("language_model"):
                key = key.replace("lm_head", "language_model.lm_head")
            return key

        return {rename(name): tensor for name, tensor in weights.items()}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from ..base import BaseModelConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class TextConfig(BaseModelConfig):
    """Hyperparameters of the Cohere-style text decoder used by Aya Vision."""

    model_type: str
    # Model width; Attention requires it to equal num_attention_heads * head_dim.
    hidden_size: int = 8192
    head_dim: int = 128
    num_hidden_layers: int = 40
    # Hidden width of the gated MLP in each decoder layer.
    intermediate_size: int = 14336
    num_attention_heads: int = 64
    # Heads used for keys/values (may be fewer than num_attention_heads).
    num_key_value_heads: int = 8
    # Base frequency passed to RoPE.
    rope_theta: float = 50000.0
    vocab_size: int = 256000
    layer_norm_eps: float = 1e-05
    # NOTE(review): presumably multiplies output logits in LanguageModel — confirm.
    logit_scale: float = 0.0625
    # Whether q/k/v/o projections carry a bias.
    attention_bias: bool = False
    # Whether the decoder LayerNorms carry a bias.
    layer_norm_bias: bool = False
    # NOTE(review): window length for sliding-window layers; presumably consumed
    # when building the rotating cache — confirm in LanguageModel.
    sliding_window: int = 4096
    # Layer i (0-based) uses sliding-window attention with RoPE unless
    # (i + 1) is a multiple of this value; those layers skip RoPE.
    sliding_window_pattern: int = 4
    max_position_embeddings: int = 4096
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class VisionConfig(BaseModelConfig):
    """Hyperparameters of the Aya Vision image encoder (vision tower)."""

    model_type: str
    # Encoder width; the projector's input is hidden_size * downsample_factor**2.
    hidden_size: int
    num_attention_heads: int
    # NOTE(review): presumably the square patch edge in pixels — confirm in vision.py.
    patch_size: int
    num_hidden_layers: int = 12
    intermediate_size: int = 3072
    image_size: int = 224
    num_channels: int = 3
    layer_norm_eps: float = 1e-6
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level Aya Vision config: text decoder, vision tower and projector."""

    text_config: TextConfig
    vision_config: VisionConfig
    # "aya_vision" additionally enables the projector's LayerNorm.
    model_type: str
    # Token id whose embedding positions are overwritten with image features.
    image_token_index: int = 255036
    # NOTE(review): not read in the visible model code; presumably used by the
    # image processor to cap image tiles — confirm.
    max_splits_per_img: int = 12
    # Pixel-shuffle factor: each factor x factor patch group becomes one token.
    downsample_factor: int = 2
    # Output width of the projector's first linear layer (halved by SwiGLU).
    alignment_intermediate_size: int = 28672
    # Epsilon of the projector's LayerNorm.
    adapter_layer_norm_eps: float = 1e-06
    # Index into the vision tower's hidden states used as image features.
    vision_feature_layer: int = -1
    # "default" drops the first token of each image sequence; "full" keeps all.
    vision_feature_select_strategy: str = "full"
    eos_token_id: Optional[List[int]] = None
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from ..base import (
|
|
7
|
+
LanguageModelOutput,
|
|
8
|
+
create_attention_mask,
|
|
9
|
+
scaled_dot_product_attention,
|
|
10
|
+
)
|
|
11
|
+
from ..cache import KVCache, RotatingKVCache
|
|
12
|
+
from .config import TextConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Attention(nn.Module):
    """Self-attention for one Cohere-style decoder layer.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_key_value_heads``. A layer is a "sliding-window" layer when
    ``(layer_idx + 1) % sliding_window_pattern != 0``. Such layers apply RoPE
    and trim an array mask to the key length; the remaining layers use no
    rotary position encoding.
    """

    def __init__(self, config: TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.head_dim = head_dim = config.head_dim
        if (head_dim * n_heads) != dim:
            # The check is an exact product, not mere divisibility, so the
            # message reports all three values involved.
            raise ValueError(
                "hidden_size must equal num_attention_heads * head_dim "
                f"(got `hidden_size`: {dim}, `num_attention_heads`: {n_heads}, "
                f"`head_dim`: {head_dim})."
            )
        self.scale = head_dim**-0.5

        attention_bias = config.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = nn.RoPE(head_dim, traditional=True, base=config.rope_theta)

        self.use_sliding_window = (layer_idx + 1) % config.sliding_window_pattern != 0

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Attend over ``x`` (unpacked as ``B, L, D``) and project back to ``D``.

        Args:
            x: Layer input.
            mask: Attention mask. On sliding-window layers, if it is an
                ``mx.array`` whose last axis differs from the key length, only
                its last ``key_len`` columns are kept.
            cache: Despite the annotation, an object exposing ``offset`` and
                ``update_and_fetch`` (presumably KVCache / RotatingKVCache).
        """
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Only sliding-window layers use rotary position embeddings.
        if self.use_sliding_window:
            if cache is None:
                queries = self.rope(queries)
                keys = self.rope(keys)
            else:
                queries = self.rope(queries, offset=cache.offset)
                keys = self.rope(keys, offset=cache.offset)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        if self.use_sliding_window and mask is not None and isinstance(mask, mx.array):
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = scaled_dot_product_attention(
            queries, keys, values, cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers; ``dim`` is the model
    width and ``hidden_dim`` the inner (expanded) width.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Projections into the expanded space (gate branch and value branch).
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        # Projection back to the model width.
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x):
        gate = nn.silu(self.gate_proj(x))
        expanded = gate * self.up_proj(x)
        return self.down_proj(expanded)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TransformerBlock(nn.Module):
    """Decoder layer with parallel attention and MLP branches.

    A single ``input_layernorm`` output feeds both ``self_attn`` and ``mlp``;
    their outputs are summed together with the un-normalized residual input.
    """

    def __init__(self, config: TextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.num_attention_heads

        # layer_idx is forwarded so the attention layer can decide whether it
        # uses a sliding window.
        self.self_attn = Attention(config, layer_idx)
        self.mlp = MLP(config.hidden_size, config.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps, bias=config.layer_norm_bias
        )
        self.config = config

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ) -> mx.array:
        """Run one decoder layer.

        Args:
            x: Input hidden states.
            mask: Optional attention mask, passed through to ``self_attn``.
            cache: Optional KV-cache object for this layer, passed through to
                ``self_attn`` (an object with ``offset``/``update_and_fetch``,
                not a ``(keys, values)`` tuple).

        Returns:
            ``attention(norm(x)) + mlp(norm(x)) + x``.
        """
        h = self.input_layernorm(x)
        attn_h = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class CohereModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers and a final LayerNorm."""

    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        # Explicit check rather than ``assert`` so the validation still runs
        # when Python is started with -O.
        if self.vocab_size <= 0:
            raise ValueError(
                f"vocab_size must be a positive integer, got {self.vocab_size}"
            )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps, bias=config.layer_norm_bias
        )

    def __call__(
        self,
        inputs: mx.array,
        inputs_embeds: mx.array = None,
        mask: mx.array = None,
        cache=None,
    ):
        """Embed (unless embeddings are supplied), run every layer, then normalize.

        Args:
            inputs: Token ids; only embedded when ``inputs_embeds`` is None.
            inputs_embeds: Precomputed input embeddings that bypass
                ``embed_tokens``.
            mask: Attention mask shared by all layers. When None it is built
                by ``create_attention_mask`` from the cache entry at index
                ``sliding_window_pattern - 1`` (presumably the first
                non-sliding layer — confirm against ``Attention``).
            cache: List with one cache per layer, or None for no caching.

        Returns:
            The final normalized hidden states.
        """
        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds

        if cache is None:
            cache = [None] * len(self.layers)

        if mask is None:
            j = self.config.sliding_window_pattern
            # Slice (not index) so a one-element list is passed through.
            mask = create_attention_mask(h, cache[j - 1 : j])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class LanguageModel(nn.Module):
    """Text decoder wrapper that turns ``CohereModel`` hidden states into logits.

    The output projection reuses the token-embedding matrix
    (``embed_tokens.as_linear``), and the logits are multiplied by
    ``config.logit_scale``.
    """

    def __init__(self, config: TextConfig):
        super().__init__()
        self.model_type = config.model_type
        self.model = CohereModel(config)
        self.config = config

    def __call__(
        self,
        inputs: mx.array,
        inputs_embeds: mx.array = None,
        mask: mx.array = None,
        cache=None,
    ):
        hidden = self.model(inputs, inputs_embeds, mask, cache)
        logits = self.model.embed_tokens.as_linear(hidden) * self.model.config.logit_scale
        return LanguageModelOutput(logits=logits)

    def make_cache(self):
        """Build one cache per layer.

        Every ``sliding_window_pattern``-th layer (indices ``p-1, 2p-1, ...``)
        gets an unbounded ``KVCache``; every other layer gets a
        ``RotatingKVCache`` capped at ``sliding_window`` entries.
        """
        return [
            KVCache()
            if idx % self.config.sliding_window_pattern
            == self.config.sliding_window_pattern - 1
            else RotatingKVCache(max_size=self.config.sliding_window, keep=0)
            for idx in range(self.config.num_hidden_layers)
        ]

    @property
    def layers(self):
        # Decoder layers of the wrapped model.
        return self.model.layers

    @property
    def head_dim(self):
        # Read straight from the shared text config.
        return self.model.config.head_dim

    @property
    def n_kv_heads(self):
        # Number of key/value heads, from the shared text config.
        return self.model.config.num_key_value_heads
|