mlx-ruby-lm 0.30.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.txt +21 -0
- data/README.md +83 -0
- data/exe/mlx_lm +7 -0
- data/lib/mlx_lm/benchmark.rb +67 -0
- data/lib/mlx_lm/chat_template.rb +41 -0
- data/lib/mlx_lm/cli.rb +113 -0
- data/lib/mlx_lm/config.rb +30 -0
- data/lib/mlx_lm/convert_utils.rb +51 -0
- data/lib/mlx_lm/generate.rb +204 -0
- data/lib/mlx_lm/load_utils.rb +87 -0
- data/lib/mlx_lm/model_args.rb +54 -0
- data/lib/mlx_lm/models/activations.rb +46 -0
- data/lib/mlx_lm/models/afm7.rb +131 -0
- data/lib/mlx_lm/models/afmoe.rb +421 -0
- data/lib/mlx_lm/models/apertus.rb +179 -0
- data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
- data/lib/mlx_lm/models/bailing_moe.rb +399 -0
- data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
- data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
- data/lib/mlx_lm/models/bitnet.rb +176 -0
- data/lib/mlx_lm/models/cache.rb +792 -0
- data/lib/mlx_lm/models/cohere.rb +150 -0
- data/lib/mlx_lm/models/cohere2.rb +224 -0
- data/lib/mlx_lm/models/dbrx.rb +286 -0
- data/lib/mlx_lm/models/deepseek.rb +239 -0
- data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
- data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
- data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
- data/lib/mlx_lm/models/dots1.rb +292 -0
- data/lib/mlx_lm/models/ernie4_5.rb +165 -0
- data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
- data/lib/mlx_lm/models/exaone.rb +169 -0
- data/lib/mlx_lm/models/exaone4.rb +233 -0
- data/lib/mlx_lm/models/exaone_moe.rb +421 -0
- data/lib/mlx_lm/models/falcon_h1.rb +102 -0
- data/lib/mlx_lm/models/gated_delta.rb +136 -0
- data/lib/mlx_lm/models/gemma.rb +159 -0
- data/lib/mlx_lm/models/gemma2.rb +198 -0
- data/lib/mlx_lm/models/gemma3.rb +85 -0
- data/lib/mlx_lm/models/gemma3_text.rb +270 -0
- data/lib/mlx_lm/models/gemma3n.rb +79 -0
- data/lib/mlx_lm/models/glm.rb +164 -0
- data/lib/mlx_lm/models/glm4.rb +180 -0
- data/lib/mlx_lm/models/glm4_moe.rb +343 -0
- data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
- data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
- data/lib/mlx_lm/models/gpt2.rb +166 -0
- data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
- data/lib/mlx_lm/models/gpt_neox.rb +178 -0
- data/lib/mlx_lm/models/gpt_oss.rb +319 -0
- data/lib/mlx_lm/models/granite.rb +170 -0
- data/lib/mlx_lm/models/granitemoe.rb +58 -0
- data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
- data/lib/mlx_lm/models/helium.rb +158 -0
- data/lib/mlx_lm/models/hunyuan.rb +378 -0
- data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
- data/lib/mlx_lm/models/internlm2.rb +160 -0
- data/lib/mlx_lm/models/internlm3.rb +237 -0
- data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
- data/lib/mlx_lm/models/jamba.rb +158 -0
- data/lib/mlx_lm/models/kimi_k25.rb +98 -0
- data/lib/mlx_lm/models/kimi_linear.rb +124 -0
- data/lib/mlx_lm/models/kimi_vl.rb +93 -0
- data/lib/mlx_lm/models/klear.rb +283 -0
- data/lib/mlx_lm/models/lfm2.rb +120 -0
- data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
- data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
- data/lib/mlx_lm/models/lille_130m.rb +148 -0
- data/lib/mlx_lm/models/llama.rb +183 -0
- data/lib/mlx_lm/models/llama4.rb +357 -0
- data/lib/mlx_lm/models/llama4_text.rb +195 -0
- data/lib/mlx_lm/models/longcat_flash.rb +153 -0
- data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
- data/lib/mlx_lm/models/mamba.rb +301 -0
- data/lib/mlx_lm/models/mamba2.rb +292 -0
- data/lib/mlx_lm/models/mimo.rb +174 -0
- data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
- data/lib/mlx_lm/models/minicpm.rb +169 -0
- data/lib/mlx_lm/models/minicpm3.rb +237 -0
- data/lib/mlx_lm/models/minimax.rb +282 -0
- data/lib/mlx_lm/models/ministral3.rb +304 -0
- data/lib/mlx_lm/models/mistral3.rb +84 -0
- data/lib/mlx_lm/models/mixtral.rb +192 -0
- data/lib/mlx_lm/models/mla.rb +75 -0
- data/lib/mlx_lm/models/nanochat.rb +167 -0
- data/lib/mlx_lm/models/nemotron.rb +202 -0
- data/lib/mlx_lm/models/nemotron_h.rb +212 -0
- data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
- data/lib/mlx_lm/models/olmo.rb +165 -0
- data/lib/mlx_lm/models/olmo2.rb +169 -0
- data/lib/mlx_lm/models/olmo3.rb +254 -0
- data/lib/mlx_lm/models/olmoe.rb +64 -0
- data/lib/mlx_lm/models/openelm.rb +208 -0
- data/lib/mlx_lm/models/phi.rb +156 -0
- data/lib/mlx_lm/models/phi3.rb +171 -0
- data/lib/mlx_lm/models/phi3small.rb +196 -0
- data/lib/mlx_lm/models/phimoe.rb +206 -0
- data/lib/mlx_lm/models/phixtral.rb +208 -0
- data/lib/mlx_lm/models/pipeline.rb +37 -0
- data/lib/mlx_lm/models/pixtral.rb +47 -0
- data/lib/mlx_lm/models/plamo.rb +169 -0
- data/lib/mlx_lm/models/plamo2.rb +173 -0
- data/lib/mlx_lm/models/qwen.rb +175 -0
- data/lib/mlx_lm/models/qwen2.rb +162 -0
- data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
- data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3.rb +167 -0
- data/lib/mlx_lm/models/qwen3_5.rb +69 -0
- data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
- data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
- data/lib/mlx_lm/models/qwen3_next.rb +147 -0
- data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
- data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
- data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
- data/lib/mlx_lm/models/rope_utils.rb +316 -0
- data/lib/mlx_lm/models/rwkv7.rb +101 -0
- data/lib/mlx_lm/models/seed_oss.rb +167 -0
- data/lib/mlx_lm/models/smollm3.rb +89 -0
- data/lib/mlx_lm/models/solar_open.rb +79 -0
- data/lib/mlx_lm/models/ssm.rb +162 -0
- data/lib/mlx_lm/models/stablelm.rb +160 -0
- data/lib/mlx_lm/models/starcoder2.rb +161 -0
- data/lib/mlx_lm/models/step3p5.rb +479 -0
- data/lib/mlx_lm/models/switch_layers.rb +221 -0
- data/lib/mlx_lm/models/telechat3.rb +192 -0
- data/lib/mlx_lm/models/youtu_llm.rb +230 -0
- data/lib/mlx_lm/models.rb +33 -0
- data/lib/mlx_lm/perplexity.rb +48 -0
- data/lib/mlx_lm/quantize.rb +131 -0
- data/lib/mlx_lm/sample_utils.rb +159 -0
- data/lib/mlx_lm/server.rb +190 -0
- data/lib/mlx_lm/tokenizer_utils.rb +158 -0
- data/lib/mlx_lm/tuner/lora.rb +165 -0
- data/lib/mlx_lm/version.rb +3 -0
- data/lib/mlx_lm/weight_utils.rb +170 -0
- data/lib/mlx_lm.rb +135 -0
- metadata +272 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
require "json"
|
|
2
|
+
|
|
3
|
+
module MlxLm
|
|
4
|
+
module LoadUtils
|
|
5
|
+
module_function
|
|
6
|
+
|
|
7
|
+
# Load a model and tokenizer from a local directory.
#
# NOTE(review): tokenizer_config is accepted for API compatibility but is
# currently ignored — load_tokenizer takes no overrides; confirm intent.
#
# @param model_path [String] Path to the model directory
# @param tokenizer_config [Hash, nil] Additional tokenizer config overrides (unused)
# @return [Array(nn::Module, TokenizerWrapper)] The loaded model and tokenizer
def load(model_path, tokenizer_config: nil)
  # The config hash returned by load_model is not needed by callers of load.
  model, _config = load_model(model_path)
  tokenizer = load_tokenizer(model_path)
  [model, tokenizer]
end
|
|
17
|
+
|
|
18
|
+
# Load model from a local directory containing config.json and safetensors.
#
# Steps: read config, look up the registered model/args classes, build the
# model, load (and optionally sanitize/quantize) the weights, then bind them.
#
# @param model_path [String] Path to the model directory
# @return [Array(nn::Module, Hash)] The loaded model and config
def load_model(model_path)
  config = Config.load(model_path)

  # Resolve the registered classes for this architecture and build the model.
  model_cls, args_cls = Models.get_classes(config)
  model = model_cls.new(args_cls.from_dict(config))

  # Pull in all shards, then let the model rename/drop weights if it wants to.
  weights = WeightUtils.load_sharded_safetensors(model_path)
  weights = model.sanitize(weights) if model.respond_to?(:sanitize)

  # Quantize only when the config carries a quantization section.
  if (quant = config["quantization"])
    Quantize.quantize_model(
      model,
      group_size: quant["group_size"] || 64,
      bits: quant["bits"] || 4,
      weights: weights
    )
  end

  # strict: false — weights absent from the checkpoint keep their init values.
  model.load_weights(weights, strict: false)

  [model, config]
end
|
|
55
|
+
|
|
56
|
+
# Load tokenizer from a local directory.
#
# Reads tokenizer.json, then opportunistically picks up the EOS token text
# from tokenizer_config.json and the EOS token id(s) from config.json.
#
# @param model_path [String] Path containing tokenizer files
# @return [TokenizerWrapper] The loaded tokenizer
def load_tokenizer(model_path)
  tokenizer_path = File.join(model_path, "tokenizer.json")
  raise "Tokenizer not found at #{tokenizer_path}" unless File.exist?(tokenizer_path)

  tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)

  # EOS token text: HF configs store it either as a plain string or as a
  # hash with a "content" field.
  eos_token = nil
  tc_path = File.join(model_path, "tokenizer_config.json")
  if File.exist?(tc_path)
    raw = JSON.parse(File.read(tc_path))["eos_token"]
    eos_token = raw.is_a?(Hash) ? raw["content"] : raw
  end

  # EOS token id(s): normalized to an array when config.json is present
  # (a missing/nil id becomes an empty array).
  eos_token_id = nil
  mc_path = File.join(model_path, "config.json")
  if File.exist?(mc_path)
    ids = JSON.parse(File.read(mc_path))["eos_token_id"]
    eos_token_id = ids.is_a?(::Array) ? ids : [ids].compact
  end

  TokenizerWrapper.new(tokenizer, eos_token: eos_token, eos_token_id: eos_token_id)
end
|
|
86
|
+
end
|
|
87
|
+
end
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
module MlxLm
  # Base class for model configuration arguments.
  # Mirrors Python's BaseModelArgs dataclass: subclasses declare fields via
  # `field`, and `from_dict` builds an instance from a config hash while
  # silently ignoring unknown keys.
  class BaseModelArgs
    # Declared fields for this class: { name (Symbol) => default }.
    # :__required__ is the sentinel for "no default supplied".
    def self.fields
      @fields ||= {}
    end

    # Declare a configuration field with an optional default.
    #
    # @param name [Symbol] field name; also becomes an attr_accessor
    # @param default [Object] default value, or :__required__ for none
    def self.field(name, default: :__required__)
      fields[name] = default

      attr_accessor name

      unless default == :__required__
        define_method(:"default_#{name}") { default }
      end
    end

    # Each subclass starts with a copy of its parent's declared fields.
    def self.inherited(subclass)
      super
      subclass.instance_variable_set(:@fields, fields.dup)
    end

    # Build an instance from a config hash, keeping only declared fields.
    # Accepts both string and symbol keys (string keys take precedence), so
    # JSON-parsed configs and hand-built symbol hashes both work.
    #
    # @param params [Hash] raw configuration
    # @return [BaseModelArgs] populated instance
    def self.from_dict(params)
      known = {}
      fields.each do |name, default|
        str_name = name.to_s
        if params.key?(str_name)
          known[name] = params[str_name]
        elsif params.key?(name)
          known[name] = params[name]
        elsif default != :__required__
          known[name] = default
        end
      end
      new(**known)
    end

    # Assign declared fields from kwargs; unknown keys are ignored and
    # missing optional fields fall back to their declared defaults.
    # Required fields left unset remain nil.
    def initialize(**kwargs)
      kwargs.each do |k, v|
        instance_variable_set(:"@#{k}", v) if self.class.fields.key?(k)
      end
      # Set defaults for any fields not provided
      self.class.fields.each do |name, default|
        next if kwargs.key?(name)
        next if default == :__required__

        instance_variable_set(:"@#{name}", default)
      end
    end
  end
end
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
module MlxLm
  module Models
    # Activation functions shared across model implementations.
    module Activations
      module_function

      # SwiGLU gating: silu(gate) multiplied elementwise by x.
      def swiglu(gate, x)
        MLX::NN.silu(gate) * x
      end

      # xIELU activation. alpha_p / alpha_n are stored pre-softplus; beta
      # scales the linear term and eps clamps the negative-branch input.
      def xielu(x, alpha_p, alpha_n, beta, eps)
        mx = MLX::Core
        pos_alpha = MLX::NN.softplus(alpha_p)
        neg_alpha = beta + MLX::NN.softplus(alpha_n)

        positive = pos_alpha * mx.square(x) + beta * x
        negative = (mx.expm1(mx.minimum(x, eps)) - x) * neg_alpha + beta * x
        mx.where(mx.greater(x, 0.0), positive, negative)
      end

      # Module form of xielu with learnable alpha parameters.
      class XieLU < MLX::NN::Module
        def initialize(alpha_p_init: 0.8, alpha_n_init: 0.8, beta: 0.5, eps: -1e-6)
          super()
          mx = MLX::Core
          # Parameters are stored pre-softplus (inverse-softplus of the
          # requested inits) so softplus at call time recovers them.
          self.alpha_p = mx.log(mx.exp(mx.array(alpha_p_init)) - 1.0)
          self.alpha_n = mx.log(mx.exp(mx.array(alpha_n_init - beta)) - 1.0)
          self.beta = mx.array(beta)
          self.eps = mx.array(eps)
        end

        def call(x)
          Activations.xielu(x, alpha_p, alpha_n, beta, eps)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
require_relative "afmoe"
|
|
2
|
+
|
|
3
|
+
module MlxLm
|
|
4
|
+
module Models
|
|
5
|
+
module Afm7
|
|
6
|
+
# Argument translation layer: afm7 configs use different key names
# (hidden_dim, num_layers, num_heads, ...) than the afmoe base config.
# This class accepts both spellings and maps the afm7 names onto the
# inherited afmoe fields.
class ModelArgs < Afmoe::ModelArgs
  field :model_type, default: "afm7"
  # afm7-style aliases for inherited afmoe fields (nil => not provided).
  field :hidden_dim, default: nil
  field :num_layers, default: nil
  field :num_kv_reuse_layers, default: 0
  field :num_heads, default: nil
  field :num_kv_heads, default: nil
  field :hidden_dim_scale_factor, default: nil

  def initialize(**kwargs)
    # Record whether any afm7-style key was passed before super applies
    # the field defaults.
    afm7_style = _afm7_style_kwargs?(kwargs)
    super

    # Map afm7 aliases onto the afmoe field names, but only when the key
    # was explicitly provided with a non-nil value.
    @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !@hidden_dim.nil?
    @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !@num_layers.nil?
    @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !@num_heads.nil?
    @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !@num_kv_heads.nil?

    # intermediate_size is derived from the scale factor when given.
    if kwargs.key?(:hidden_dim_scale_factor) && !@hidden_dim_scale_factor.nil? && !@hidden_size.nil?
      @intermediate_size = (@hidden_size * @hidden_dim_scale_factor.to_f).to_i
    end

    # Integer division — NOTE(review): assumes hidden_size is a multiple
    # of the head count; confirm truncation is acceptable otherwise.
    if !@hidden_size.nil? && !@num_attention_heads.nil? && @num_attention_heads.to_i > 0
      @head_dim = @hidden_size / @num_attention_heads
    end

    # Layers that do not reuse KV are counted as "dense" (clamped at 0);
    # without an explicit reuse count, afm7-style configs treat every
    # layer as dense.
    if kwargs.key?(:num_kv_reuse_layers) && !@num_hidden_layers.nil?
      @num_dense_layers = [@num_hidden_layers.to_i - @num_kv_reuse_layers.to_i, 0].max
    elsif afm7_style && !@num_hidden_layers.nil?
      @num_dense_layers = @num_hidden_layers
    end

    # afm7 configs default to a single-expert, non-muP, full-attention
    # setup unless the caller overrides these explicitly.
    if afm7_style
      @num_experts = 1 unless kwargs.key?(:num_experts)
      @num_experts_per_tok = 1 unless kwargs.key?(:num_experts_per_tok)
      @num_shared_experts = 0 unless kwargs.key?(:num_shared_experts)
      @mup_enabled = false unless kwargs.key?(:mup_enabled)
      @layer_types = Array.new(@num_hidden_layers) { "full_attention" } unless kwargs.key?(:layer_types)
    end

    # Final fallbacks shared by both config styles.
    @num_key_value_heads ||= @num_attention_heads
    @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } unless @num_hidden_layers.nil?
  end

  # Re-serialize into the string-keyed hash shape that
  # Afmoe::ModelArgs.from_dict expects.
  def to_afmoe_dict
    {
      "model_type" => @model_type,
      "layer_types" => @layer_types,
      "vocab_size" => @vocab_size,
      "hidden_size" => @hidden_size,
      "intermediate_size" => @intermediate_size,
      "moe_intermediate_size" => @moe_intermediate_size,
      "num_hidden_layers" => @num_hidden_layers,
      "num_attention_heads" => @num_attention_heads,
      "num_key_value_heads" => @num_key_value_heads,
      "head_dim" => @head_dim,
      "max_position_embeddings" => @max_position_embeddings,
      "rms_norm_eps" => @rms_norm_eps,
      "rope_theta" => @rope_theta,
      "rope_scaling" => @rope_scaling,
      "tie_word_embeddings" => @tie_word_embeddings,
      "num_experts" => @num_experts,
      "num_experts_per_tok" => @num_experts_per_tok,
      "num_shared_experts" => @num_shared_experts,
      "num_dense_layers" => @num_dense_layers,
      "route_norm" => @route_norm,
      "route_scale" => @route_scale,
      "score_func" => @score_func,
      "n_group" => @n_group,
      "topk_group" => @topk_group,
      "sliding_window" => @sliding_window,
      "mup_enabled" => @mup_enabled,
    }
  end

  private

  # True when any afm7-specific key name is present in the raw kwargs.
  def _afm7_style_kwargs?(kwargs)
    kwargs.key?(:hidden_dim) ||
      kwargs.key?(:num_layers) ||
      kwargs.key?(:num_heads) ||
      kwargs.key?(:num_kv_heads) ||
      kwargs.key?(:num_kv_reuse_layers) ||
      kwargs.key?(:hidden_dim_scale_factor)
  end
end
|
|
92
|
+
|
|
93
|
+
# afm7 is a thin delegator around the afmoe implementation: args are
# translated via ModelArgs#to_afmoe_dict and every call is forwarded to
# the wrapped afmoe model.
class Model < MLX::NN::Module
  def initialize(args)
    super()
    @args = args
    self.model_type = args.model_type
    self.wrapped_model = Afmoe::Model.new(Afmoe::ModelArgs.from_dict(args.to_afmoe_dict))
  end

  # Forward pass — delegated unchanged.
  def call(inputs, cache: nil)
    wrapped_model.call(inputs, cache: cache)
  end

  # Weight-name sanitization — delegated unchanged.
  def sanitize(weights)
    wrapped_model.sanitize(weights)
  end

  def layers
    wrapped_model.layers
  end

  # Returns nil when the wrapped model does not build its own cache.
  def make_cache
    wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
  end

  def cast_predicate
    wrapped_model.cast_predicate
  end

  def quant_predicate
    wrapped_model.quant_predicate
  end
end
|
|
127
|
+
|
|
128
|
+
Models.register("afm7", Model, ModelArgs)
|
|
129
|
+
end
|
|
130
|
+
end
|
|
131
|
+
end
|