mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0

mlx_raclate/tuner/utils.py

@@ -0,0 +1,292 @@
# Copyright © 2024 Apple Inc.
from pathlib import Path
from typing import Dict

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten

EMBEDDING_LAYER_NAMES = {
    "gemma3_text": ["embed_tokens"],
    "lfm2": ["embed_tokens"],
    "qwen3": ["embed_tokens"],
    "t5gemma_encoder": ["embed_tokens"],
    "modernbert": ["embeddings"]
}

def build_schedule(schedule_config: Dict):
    """
    Build a learning rate schedule from the given config.
    Adapted for MLX Optimizers.
    """
    name = schedule_config["name"]
    arguments = schedule_config["arguments"]
    initial_lr = arguments[0]

    # Create the main schedule function
    if name == "constant":
        # Create a callable that ignores the step and returns the LR
        bound_schedule_fn = lambda _: initial_lr
    else:
        # For cosine_decay, linear_schedule, etc.
        schedule_fn = getattr(opt, name)
        bound_schedule_fn = schedule_fn(*arguments)

    # Check for warmup
    warmup_steps = schedule_config.get("warmup_steps", 0)

    if warmup_steps > 0:
        warmup_init = schedule_config.get("warmup_init", 0.0)

        # Linear warmup: 0 -> initial_lr
        warmup_fn = opt.linear_schedule(
            warmup_init, initial_lr, warmup_steps
        )

        return opt.join_schedules(
            [warmup_fn, bound_schedule_fn], [warmup_steps]
        )
    else:
        return bound_schedule_fn

### We don't currently use LoRA or DoRA in MLX Raclate tuner
### Kept here for reference, and if we need to re-implement it later
# def linear_to_lora_layers(
#     model: nn.Module,
#     num_layers: int,
#     config: Dict,
#     use_dora: bool = False,
# ):
#     """
#     Convert some of the models linear layers to lora layers.

#     Args:
#         model (nn.Module): The neural network model.
#         num_layers (int): The number of blocks to convert to lora layers
#             starting from the last layer.
#         config (dict): More configuration parameters for LoRA, including the
#             rank, scale, and optional layer keys.
#         use_dora (bool): If True, uses DoRA instead of LoRA.
#             Default: ``False``
#     """
#     if num_layers > len(model.layers):
#         raise ValueError(
#             f"Requested {num_layers} LoRA layers "
#             f"but the model only has {len(model.layers)} layers."
#         )

#     def to_lora(layer):
#         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
#             LoRALayer = DoRALinear if use_dora else LoRALinear
#         elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
#             if use_dora:
#                 raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
#             LoRALayer = LoRASwitchLinear
#         elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
#             LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
#         else:
#             raise ValueError(
#                 f"Can't convert layer of type {type(layer).__name__} to LoRA"
#             )

#         return LoRALayer.from_base(
#             layer,
#             r=config["rank"],
#             scale=config["scale"],
#             dropout=config["dropout"],
#         )

#     keys = config.get("keys", None)
#     if keys is not None:
#         keys = set(keys)
#     elif model.model_type in [
#         "mistral",
#         "llama",
#         "phi",
#         "mixtral",
#         "nemotron",
#         "stablelm",
#         "qwen2",
#         "qwen2_moe",
#         "phimoe",
#         "gemma",
#         "gemma2",
#         "starcoder2",
#         "cohere",
#         "cohere2",
#         "minicpm",
#         "deepseek",
#         "olmo2",
#     ]:
#         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
#         if model.model_type in ["mixtral", "phimoe"]:
#             keys.add("block_sparse_moe.gate")
#         if model.model_type == "qwen2_moe":
#             keys.add("mlp.gate")
#             keys.add("mlp.shared_expert_gate")

#     elif model.model_type == "gpt_bigcode":
#         keys = set(["attn.c_attn"])
#     elif model.model_type == "gpt2":
#         keys = set(["attn.c_attn"])
#     elif model.model_type == "gpt_neox":
#         keys = set(["attention.query_key_value"])
#     elif model.model_type == "olmo":
#         keys = set(["att_proj"])
#     elif model.model_type == "openelm":
#         keys = set(["attn.qkv_proj"])
#     elif model.model_type == "phi3":
#         keys = set(["self_attn.qkv_proj"])
#     elif model.model_type == "phi-msft":
#         keys = set(["mixer.Wqkv", "moe.gate"])
#     elif model.model_type == "dbrx":
#         keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
#     elif model.model_type == "internlm2":
#         keys = set(["attention.wqkv", "attention.wo"])
#     elif model.model_type == "deepseek_v2":
#         keys = set(
#             [
#                 "self_attn.q_proj",
#                 "self_attn.q_a_proj",
#                 "self_attn.q_b_proj",
#                 "self_attn.kv_a_proj_with_mqa",
#                 "self_attn.kv_b_proj",
#             ]
#         )
#     elif model.model_type == "mamba":
#         keys = set(
#             [
#                 "mixer.in_proj",
#                 "mixer.x_proj",
#                 "mixer.dt_proj",
#                 "mixer.out_proj",
#             ]
#         )
#     elif model.model_type == "exaone":
#         keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
#     else:
#         raise ValueError(f"Lora does not support {model.model_type}")

#     for l in model.layers[-min(num_layers, 0) :]:
#         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
#         if lora_layers:
#             l.update_modules(tree_unflatten(lora_layers))

#     lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
#     if lora_modules:
#         model.update_modules(tree_unflatten(lora_modules))


# def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
#     """
#     Load any fine-tuned adapters / layers.

#     Args:
#         model (nn.Module): The neural network model.
#         adapter_path (str): Path to the adapter configuration file.

#     Returns:
#         nn.Module: The updated model with LoRA layers applied.
#     """
#     adapter_path = Path(adapter_path)
#     if not adapter_path.exists():
#         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
#     with open(adapter_path / "adapter_config.json", "r") as fid:
#         config = types.SimpleNamespace(**json.load(fid))
#     fine_tune_type = getattr(config, "fine_tune_type", "lora")
#     if fine_tune_type != "full":
#         linear_to_lora_layers(
#             model,
#             config.num_layers,
#             config.lora_parameters,
#             use_dora=(fine_tune_type == "dora"),
#         )
#     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
#     return model


def dequantize(model: nn.Module) -> nn.Module:
    """
    Dequantize the quantized linear layers in the model.

    Args:
        model (nn.Module): The model with quantized linear layers.

    Returns:
        nn.Module: The model with dequantized layers.
    """
    de_quantize_layers = []
    for name, module in model.named_modules():
        if isinstance(module, nn.QuantizedLinear):
            bias = "bias" in module
            weight = module.weight
            weight = mx.dequantize(
                weight,
                module.scales,
                module.biases,
                module.group_size,
                module.bits,
            ).astype(mx.float16)
            output_dims, input_dims = weight.shape
            linear = nn.Linear(input_dims, output_dims, bias=bias)
            linear.weight = weight
            if bias:
                linear.bias = module.bias
            de_quantize_layers.append((name, linear))
        if isinstance(module, nn.QuantizedEmbedding):
            weight = mx.dequantize(
                module.weight,
                module.scales,
                module.biases,
                module.group_size,
                module.bits,
            ).astype(mx.float16)
            num_embeddings, dims = weight.shape
            emb = nn.Embedding(num_embeddings, dims)
            emb.weight = weight
            de_quantize_layers.append((name, emb))

    if len(de_quantize_layers) > 0:
        model.update_modules(tree_unflatten(de_quantize_layers))
    return model


# def remove_lora_layers(model: nn.Module) -> nn.Module:
#     """
#     Remove the LoRA layers from the model.

#     Args:
#         model (nn.Module): The model with LoRA layers.

#     Returns:
#         nn.Module: The model without LoRA layers.
#     """
#     reset_layers = []
#     for name, module in model.named_modules():
#         if isinstance(module, LoRALinear):
#             reset_layers.append((name, module.linear))
#     if len(reset_layers) > 0:
#         model.update_modules(tree_unflatten(reset_layers))
#     return model


def nparams(module):
    if hasattr(module, "bits"):
        n = 0 if not hasattr(module, "bias") else module.bias.size
        return n + module.weight.size * 32 // module.bits
    return sum(v.size for _, v in tree_flatten(module.parameters()))


def print_trainable_parameters(model):
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )
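
The schedule config consumed by build_schedule above is a plain dict whose "arguments" list is forwarded positionally to the named mlx.optimizers schedule, with arguments[0] doubling as the warmup target. A minimal usage sketch, with illustrative values and an assumed public import path (not prescribed by the package):

# Illustrative sketch only: "cosine_decay" resolves to mlx.optimizers.cosine_decay(init, decay_steps);
# warmup_steps > 0 prepends a linear warmup via opt.join_schedules, as in the code above.
from mlx_raclate.tuner.utils import build_schedule  # assumed import path

schedule_config = {
    "name": "cosine_decay",      # or "constant", "linear_schedule", ...
    "arguments": [1e-5, 1000],   # (initial_lr, decay_steps) for cosine_decay
    "warmup_steps": 100,
    "warmup_init": 0.0,
}

lr_schedule = build_schedule(schedule_config)
print(lr_schedule(0), lr_schedule(100), lr_schedule(1100))  # LR at start, end of warmup, end of decay

The resulting callable can be handed straight to an MLX optimizer, e.g. opt.AdamW(learning_rate=lr_schedule), which is the usual way MLX consumes a schedule.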
File without changes
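
The server.py hunk below drives every pipeline through load from mlx_raclate.utils.utils, which returns a (model, tokenizer) pair. A rough offline sketch of that same call outside FastAPI; the checkpoint name is a placeholder, and whether it supports a given pipeline depends on the pipeline implementations in utils.py:

# Rough sketch of what the /predict endpoint does internally (see server.py below).
from mlx_raclate.utils.utils import PIPELINES, load

print(PIPELINES)  # names accepted by the "pipeline" argument

model, tokenizer = load(
    "answerdotai/ModernBERT-Large-Instruct",  # placeholder checkpoint
    pipeline="embeddings",
)

batch = tokenizer._tokenizer(
    ["MLX runs natively on Apple Silicon."],
    return_tensors="mlx", padding=True, truncation=True,
)
out = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    return_dict=True,
)
print(out["embeddings"].shape)  # normalized pooled vectors, one per input text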

mlx_raclate/utils/server.py

@@ -0,0 +1,390 @@
# server.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional, Any, Union, Dict
import uvicorn
import gc
import mlx.core as mx
from mlx_raclate.utils.utils import PIPELINES, load

app = FastAPI(
    title="Raclate Inference API",
    description="API for using Raclate pipelines (ModernBERT, LFM2, Qwen, etc.) on Apple Silicon",
    version="0.1.0"
)

# TODO:
# Separate Services: For complete isolation, run each pipeline type as a separate FastAPI service and use a lightweight API gateway to route requests.
# Worker Pool Architecture: Implement a worker pool where each worker specializes in a specific pipeline, and a dispatcher routes requests to the appropriate worker.


model_cache = {}

def get_model(model_name: str, pipeline_name: str, config_file: Optional[Dict] = None):
    """
    Factory function to get or create the appropriate model.
    Checks cache based on name, pipeline, AND configuration.
    """
    global model_cache

    # Create a cache key string that includes config to differentiate
    # e.g. LFM2 with late_interaction=True vs False
    config_key = str(sorted(config_file.items())) if config_file else "default"

    current_key = f"{model_name}_{pipeline_name}_{config_key}"
    cached_key = model_cache.get("key", None)

    if cached_key == current_key:
        return model_cache

    # Garbage collection before loading new model
    if model_cache:
        print(f"Unloading previous model: {model_cache.get('model_name')}")
        model_cache = {}
        mx.eval()  # Ensure evaluation of any pending ops
        gc.collect()
        mx.metal.clear_cache()

    print(f"Loading model: {model_name} | Pipeline: {pipeline_name} | Config: {config_file}")

    try:
        model, tokenizer = load(
            model_name,
            pipeline=pipeline_name,
            model_config=config_file
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

    model_cache = {
        "key": current_key,
        "model_name": model_name,
        "pipeline": pipeline_name,
        "model": model,
        "tokenizer": tokenizer,
    }

    return model_cache

# -----------------------------------------------------------------------------
# Pydantic Models
# -----------------------------------------------------------------------------

class PredictionRequest(BaseModel):
    model: str
    pipeline: str
    text: Union[str, List[str]]

    # Optional parameters depending on pipeline
    text_pair: Optional[Union[str, List[str]]] = Field(None, description="Secondary text for sequence classification pairs (e.g. NLI)")
    reference_text: Optional[Union[str, List[str]]] = Field(None, description="Documents/References for similarity search")

    # Configuration
    config_file: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Configuration overrides (e.g., {'use_late_interaction': True})")
    label_candidates: Optional[Union[Dict[str, str], List[str]]] = Field(None, description="For zero-shot-classification only")

# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------

@app.post("/predict")
async def predict(request: PredictionRequest):
    """
    Main inference endpoint handling multiple pipelines.
    """
    if request.pipeline not in PIPELINES:
        raise HTTPException(status_code=400, detail=f"Pipeline '{request.pipeline}' not supported. Available: {PIPELINES}")

    # Standardize inputs to lists
    texts = request.text if isinstance(request.text, list) else [request.text]

    if len(texts) > 32:
        raise HTTPException(status_code=400, detail="Batch size should not exceed 32 to protect memory")

    # Load Model
    model_info = get_model(request.model, request.pipeline, request.config_file)
    tokenizer = model_info["tokenizer"]
    model = model_info["model"]

    # Determine generic args
    max_len = getattr(model.config, "max_position_embeddings", 512)
    result = {}

    # -------------------------------------------------------------------------
    # Pipeline: Text Classification (Sentiment, NLI, Regression)
    # -------------------------------------------------------------------------
    if request.pipeline == "text-classification":
        text_pairs = None
        if request.text_pair:
            text_pairs = request.text_pair if isinstance(request.text_pair, list) else [request.text_pair]
            if len(text_pairs) != len(texts):
                raise HTTPException(status_code=400, detail="Length of text and text_pair must match")

        inputs = tokenizer._tokenizer(
            texts,
            text_pairs,
            return_tensors="mlx",
            padding=True,
            truncation=True,
            max_length=max_len
        )

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dict=True
        )

        probs = outputs["probabilities"]  # Shape: [batch, num_labels]

        # Format output
        batch_results = []
        id2label = getattr(model.config, "id2label", None)

        # Convert to python list structure
        probs_list = probs.tolist()

        for i, row in enumerate(probs_list):
            if id2label:
                # Return dictionary mapping label -> score
                item_res = [[id2label[str(j)], score] for j, score in enumerate(row)]
                # Sort by score descending
                item_res = sorted(item_res, key=lambda x: x[1], reverse=True)
                batch_results.append(item_res)
            else:
                # Just return raw scores (e.g. regression or missing config)
                batch_results.append(row)

        result = {"predictions": batch_results}

    # -------------------------------------------------------------------------
    # Pipeline: Sentence Similarity (Dense & Late Interaction)
    # -------------------------------------------------------------------------
    elif request.pipeline in ["sentence-similarity", "sentence-transformers"]:
        if not request.reference_text:
            raise HTTPException(status_code=400, detail="reference_text is required for sentence-similarity")

        refs = request.reference_text if isinstance(request.reference_text, list) else [request.reference_text]

        q_inputs = tokenizer._tokenizer(texts, return_tensors="mlx", padding=True, truncation=True, max_length=max_len)
        d_inputs = tokenizer._tokenizer(refs, return_tensors="mlx", padding=True, truncation=True, max_length=max_len)

        # The model handles the complexity (Cosine vs MaxSim) internally based on config
        outputs = model(
            input_ids=q_inputs['input_ids'],
            reference_input_ids=d_inputs['input_ids'],
            attention_mask=q_inputs['attention_mask'],
            reference_attention_mask=d_inputs['attention_mask'],
            return_dict=True
        )

        # Returns matrix: [batch_size, num_references]
        result = {"similarities": outputs['similarities'].tolist()}

    # -------------------------------------------------------------------------
    # Pipeline: Raw Embeddings
    # -------------------------------------------------------------------------
    elif request.pipeline == "embeddings":
        inputs = tokenizer._tokenizer(texts, return_tensors="mlx", padding=True, truncation=True, max_length=max_len)

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dict=True
        )

        # 'embeddings' is the normalized pooled output
        result = {"embeddings": outputs['embeddings'].tolist()}

    # -------------------------------------------------------------------------
    # Pipeline: Masked LM (Raw)
    # -------------------------------------------------------------------------
    elif request.pipeline == "masked-lm":
        inputs = tokenizer._tokenizer(texts, return_tensors="mlx", padding=True, truncation=True, max_length=max_len)

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dict=True
        )

        # Here we return the logits for the mask token if present, else empty.

        mask_token_id = tokenizer.mask_token_id
        predictions = outputs["logits"]
        mask_positions = mx.argmax(inputs['input_ids'] == mask_token_id, axis=1)

        batch_results = []
        for i in range(len(texts)):
            if mask_token_id in inputs['input_ids'][i]:
                pos = mask_positions[i].item()
                # Top 5 for the mask
                token_logits = predictions[i, pos]
                probs = mx.softmax(token_logits)
                top_k = 5
                sorted_indices = mx.argsort(probs)[::-1][:top_k]

                top_tokens = []
                for idx in sorted_indices.tolist():
                    top_tokens.append({
                        "token": tokenizer.decode([idx]),
                        "score": probs[idx].item()
                    })
                batch_results.append(top_tokens)
            else:
                batch_results.append(None)

        result = {"masked_predictions": batch_results}

    # -------------------------------------------------------------------------
    # Pipeline: Zero-Shot Classification (Custom Logic via Masked LM)
    # -------------------------------------------------------------------------
    elif request.pipeline == "zero-shot-classification":
        if not request.label_candidates:
            raise HTTPException(status_code=400, detail="label_candidates required for zero-shot")

        # Reuse the logic from your old server, adapted for batching
        if isinstance(request.label_candidates, dict):
            categories = "\n".join([f"{i}: {k} ({v})" for i, (k, v) in enumerate(request.label_candidates.items())])
            num_cats = len(request.label_candidates)
        else:
            categories = "\n".join([f"{i}: {label}" for i, label in enumerate(request.label_candidates)])
            num_cats = len(request.label_candidates)

        classification_inputs = []
        for text in texts:
            # Answer.ai / ModernBERT style prompt
            classification_input = f"""You will be given a text and categories to classify the text.

{text}

Read the text carefully and select the right category from the list. Only provide the index of the category:
{categories}

ANSWER: [unused0][MASK]
"""
            classification_inputs.append(classification_input)

        inputs = tokenizer._tokenizer(
            classification_inputs,
            return_tensors="mlx",
            padding=True,
            truncation=True,
            max_length=max_len
        )

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs.get('attention_mask', None),
            return_dict=True
        )

        predictions = outputs["logits"]
        mask_token_id = tokenizer.mask_token_id
        mask_positions = mx.argmax(inputs['input_ids'] == mask_token_id, axis=1)

        batch_results = []
        for i in range(len(texts)):
            mask_position = mask_positions[i].item()
            masked_token_predictions = predictions[i, mask_position]

            probs = mx.softmax(masked_token_predictions)
            top_k = min(5, num_cats)

            # Sort generic probabilities
            sorted_indices = mx.argsort(probs)[::-1][:top_k]
            top_probs = probs[sorted_indices]

            item_res = []
            for idx, logit in zip(sorted_indices.tolist(), top_probs.tolist()):
                item_res.append({"label_index": tokenizer.decode([idx]), "score": logit})

            batch_results.append(item_res)

        result = {"classification": batch_results}

    # Clean up
    mx.metal.clear_cache()
    gc.collect()

    return result

@app.get("/status")
async def status():
    return {
        "status": "online",
        "loaded_model": model_cache.get("model_name"),
        "loaded_pipeline": model_cache.get("pipeline"),
        "loaded_config_key": model_cache.get("key")
    }

@app.post("/unload")
async def unload_model():
    global model_cache
    if not model_cache:
        return {"message": "No model loaded"}

    name = model_cache.get("model_name")
    model_cache = {}
    gc.collect()
    mx.metal.clear_cache()
    return {"message": f"Unloaded {name}"}

if __name__ == "__main__":
    uvicorn.run("mlx_raclate.utils.server:app", host="0.0.0.0", port=8000, workers=1)

### EXAMPLE
'''
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "text": [
      "The new MacBook Pro with M3 chip delivers exceptional performance and battery life.",
      "I was really disappointed with the customer service at that restaurant.",
      "This movie has beautiful cinematography but the plot is confusing.",
      "The aging of the population is the archetype of an unpleasant truth for mainstream media readers and for voters, which does not encourage anyone to put it on the table. Age pyramids, birth and fertility indicators, and celibacy rates in all developed countries indicate that the situation is worrying. Among these countries, some managed to stay on-track until about 10 years ago but they eventually fell into line."
    ],
    "model": "answerdotai/ModernBERT-Large-Instruct",
    "pipeline": "zero-shot-classification",
    "label_candidates": {
      "artificial intelligence": "The study of computer science that focuses on the creation of intelligent machines that work and react like humans.",
      "physics": "The study of matter, energy, and the fundamental forces of nature.",
      "society": "The aggregate of people living together in a more or less ordered community.",
      "biology": "The study of living organisms, divided into many specialized fields that cover their morphology, physiology, anatomy, behavior, origin, and distribution.",
      "environment": "The surroundings or conditions in which a person, animal, or plant lives or operates.",
      "health": "The state of being free from illness or injury.",
      "finance": "The management of large amounts of money, especially by governments or large companies."
    }
  }'
'''

'''
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "NousResearch/Minos-v1",
    "pipeline": "text-classification",
    "text": [
      "I absolutely love this new framework!",
      "The service was terrible and slow."
    ]
  }'
'''

'''
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "LiquidAI/LFM2-ColBERT-350M",
    "pipeline": "sentence-similarity",
    "config_file": {
      "use_late_interaction": true
    },
    "text": ["What is liquid AI?"],
    "reference_text": [
      "Liquid AI builds efficient foundation models.",
      "Water is a liquid state of matter."
    ]
  }'
'''