mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,292 @@
1
+ # Copyright © 2024 Apple Inc.
2
+ from pathlib import Path
3
+ from typing import Dict
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+ import mlx.optimizers as opt
8
+ from mlx.utils import tree_flatten, tree_unflatten
9
+
10
# Names of the embedding sub-modules for each supported model type
# (keys presumably match the config's model_type — verify against callers).
EMBEDDING_LAYER_NAMES = dict(
    gemma3_text=["embed_tokens"],
    lfm2=["embed_tokens"],
    qwen3=["embed_tokens"],
    t5gemma_encoder=["embed_tokens"],
    modernbert=["embeddings"],
)
17
+
18
def build_schedule(schedule_config: Dict):
    """
    Build a learning rate schedule from the given config.
    Adapted for MLX Optimizers.

    Args:
        schedule_config (dict): Schedule description with the keys
            ``"name"`` -- ``"constant"`` or the name of a schedule factory in
            ``mlx.optimizers`` (e.g. ``"cosine_decay"``, ``"linear_schedule"``);
            ``"arguments"`` -- positional arguments for that factory, whose
            first element is taken as the peak learning rate;
            ``"warmup_steps"`` (optional, default 0) -- length of a linear
            warmup prepended to the schedule;
            ``"warmup_init"`` (optional, default 0.0) -- learning rate at the
            first warmup step.

    Returns:
        A callable mapping the optimizer step to the learning rate.

    Raises:
        KeyError: if ``"name"`` or ``"arguments"`` is missing.
        AttributeError: if ``"name"`` is not a schedule in ``mlx.optimizers``.
    """
    name = schedule_config["name"]
    arguments = schedule_config["arguments"]
    initial_lr = arguments[0]

    if name == "constant":
        def bound_schedule_fn(_step):
            # Return an mx.array, not a Python float: MLX optimizers store the
            # schedule output as their learning rate and call ``.astype`` on it.
            return mx.array(initial_lr)
    else:
        # For cosine_decay, linear_schedule, etc.
        bound_schedule_fn = getattr(opt, name)(*arguments)

    warmup_steps = schedule_config.get("warmup_steps", 0)
    if warmup_steps <= 0:
        return bound_schedule_fn

    warmup_init = schedule_config.get("warmup_init", 0.0)
    # Linear warmup warmup_init -> initial_lr; once ``warmup_steps`` is reached,
    # join_schedules hands over to the main schedule (offset by the boundary).
    warmup_fn = opt.linear_schedule(warmup_init, initial_lr, warmup_steps)
    return opt.join_schedules([warmup_fn, bound_schedule_fn], [warmup_steps])
52
+
53
+ ### We don't currently use LoRA or DoRA in MLX Raclate tuner
54
+ ### Kept here for reference, and if we need to re-implement it later
55
+ # def linear_to_lora_layers(
56
+ # model: nn.Module,
57
+ # num_layers: int,
58
+ # config: Dict,
59
+ # use_dora: bool = False,
60
+ # ):
61
+ # """
62
+ # Convert some of the models linear layers to lora layers.
63
+
64
+ # Args:
65
+ # model (nn.Module): The neural network model.
66
+ # num_layers (int): The number of blocks to convert to lora layers
67
+ # starting from the last layer.
68
+ # config (dict): More configuration parameters for LoRA, including the
69
+ # rank, scale, and optional layer keys.
70
+ # use_dora (bool): If True, uses DoRA instead of LoRA.
71
+ # Default: ``False``
72
+ # """
73
+ # if num_layers > len(model.layers):
74
+ # raise ValueError(
75
+ # f"Requested {num_layers} LoRA layers "
76
+ # f"but the model only has {len(model.layers)} layers."
77
+ # )
78
+
79
+ # def to_lora(layer):
80
+ # if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
81
+ # LoRALayer = DoRALinear if use_dora else LoRALinear
82
+ # elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
83
+ # if use_dora:
84
+ # raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
85
+ # LoRALayer = LoRASwitchLinear
86
+ # elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
87
+ # LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
88
+ # else:
89
+ # raise ValueError(
90
+ # f"Can't convert layer of type {type(layer).__name__} to LoRA"
91
+ # )
92
+
93
+ # return LoRALayer.from_base(
94
+ # layer,
95
+ # r=config["rank"],
96
+ # scale=config["scale"],
97
+ # dropout=config["dropout"],
98
+ # )
99
+
100
+ # keys = config.get("keys", None)
101
+ # if keys is not None:
102
+ # keys = set(keys)
103
+ # elif model.model_type in [
104
+ # "mistral",
105
+ # "llama",
106
+ # "phi",
107
+ # "mixtral",
108
+ # "nemotron",
109
+ # "stablelm",
110
+ # "qwen2",
111
+ # "qwen2_moe",
112
+ # "phimoe",
113
+ # "gemma",
114
+ # "gemma2",
115
+ # "starcoder2",
116
+ # "cohere",
117
+ # "cohere2",
118
+ # "minicpm",
119
+ # "deepseek",
120
+ # "olmo2",
121
+ # ]:
122
+ # keys = set(["self_attn.q_proj", "self_attn.v_proj"])
123
+ # if model.model_type in ["mixtral", "phimoe"]:
124
+ # keys.add("block_sparse_moe.gate")
125
+ # if model.model_type == "qwen2_moe":
126
+ # keys.add("mlp.gate")
127
+ # keys.add("mlp.shared_expert_gate")
128
+
129
+ # elif model.model_type == "gpt_bigcode":
130
+ # keys = set(["attn.c_attn"])
131
+ # elif model.model_type == "gpt2":
132
+ # keys = set(["attn.c_attn"])
133
+ # elif model.model_type == "gpt_neox":
134
+ # keys = set(["attention.query_key_value"])
135
+ # elif model.model_type == "olmo":
136
+ # keys = set(["att_proj"])
137
+ # elif model.model_type == "openelm":
138
+ # keys = set(["attn.qkv_proj"])
139
+ # elif model.model_type == "phi3":
140
+ # keys = set(["self_attn.qkv_proj"])
141
+ # elif model.model_type == "phi-msft":
142
+ # keys = set(["mixer.Wqkv", "moe.gate"])
143
+ # elif model.model_type == "dbrx":
144
+ # keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
145
+ # elif model.model_type == "internlm2":
146
+ # keys = set(["attention.wqkv", "attention.wo"])
147
+ # elif model.model_type == "deepseek_v2":
148
+ # keys = set(
149
+ # [
150
+ # "self_attn.q_proj",
151
+ # "self_attn.q_a_proj",
152
+ # "self_attn.q_b_proj",
153
+ # "self_attn.kv_a_proj_with_mqa",
154
+ # "self_attn.kv_b_proj",
155
+ # ]
156
+ # )
157
+ # elif model.model_type == "mamba":
158
+ # keys = set(
159
+ # [
160
+ # "mixer.in_proj",
161
+ # "mixer.x_proj",
162
+ # "mixer.dt_proj",
163
+ # "mixer.out_proj",
164
+ # ]
165
+ # )
166
+ # elif model.model_type == "exaone":
167
+ # keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
168
+ # else:
169
+ # raise ValueError(f"Lora does not support {model.model_type}")
170
+
171
+ # for l in model.layers[-min(num_layers, 0) :]:
172
+ # lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
173
+ # if lora_layers:
174
+ # l.update_modules(tree_unflatten(lora_layers))
175
+
176
+ # lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
177
+ # if lora_modules:
178
+ # model.update_modules(tree_unflatten(lora_modules))
179
+
180
+
181
+ # def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
182
+ # """
183
+ # Load any fine-tuned adapters / layers.
184
+
185
+ # Args:
186
+ # model (nn.Module): The neural network model.
187
+ # adapter_path (str): Path to the adapter configuration file.
188
+
189
+ # Returns:
190
+ # nn.Module: The updated model with LoRA layers applied.
191
+ # """
192
+ # adapter_path = Path(adapter_path)
193
+ # if not adapter_path.exists():
194
+ # raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
195
+ # with open(adapter_path / "adapter_config.json", "r") as fid:
196
+ # config = types.SimpleNamespace(**json.load(fid))
197
+ # fine_tune_type = getattr(config, "fine_tune_type", "lora")
198
+ # if fine_tune_type != "full":
199
+ # linear_to_lora_layers(
200
+ # model,
201
+ # config.num_layers,
202
+ # config.lora_parameters,
203
+ # use_dora=(fine_tune_type == "dora"),
204
+ # )
205
+ # model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
206
+ # return model
207
+
208
+
209
def dequantize(model: nn.Module, dtype=mx.float16) -> nn.Module:
    """
    Dequantize the quantized linear and embedding layers in the model.

    Every ``nn.QuantizedLinear`` is replaced by an ``nn.Linear`` and every
    ``nn.QuantizedEmbedding`` by an ``nn.Embedding``. The model is updated in
    place (via ``update_modules``) and also returned.

    Args:
        model (nn.Module): The model with quantized layers.
        dtype: Dtype of the dequantized weights. Defaults to ``mx.float16``;
            pass e.g. ``mx.bfloat16`` to keep the dynamic range of a bf16
            checkpoint.

    Returns:
        nn.Module: The model with dequantized layers.
    """

    def _dequantized_weight(module):
        # Unpack the weight using the module's own quantization parameters.
        return mx.dequantize(
            module.weight,
            module.scales,
            module.biases,
            module.group_size,
            module.bits,
        ).astype(dtype)

    replacements = []
    for name, module in model.named_modules():
        if isinstance(module, nn.QuantizedLinear):
            weight = _dequantized_weight(module)
            has_bias = "bias" in module
            output_dims, input_dims = weight.shape
            linear = nn.Linear(input_dims, output_dims, bias=has_bias)
            linear.weight = weight
            if has_bias:
                linear.bias = module.bias
            replacements.append((name, linear))
        elif isinstance(module, nn.QuantizedEmbedding):
            weight = _dequantized_weight(module)
            num_embeddings, dims = weight.shape
            embedding = nn.Embedding(num_embeddings, dims)
            embedding.weight = weight
            replacements.append((name, embedding))

    if replacements:
        model.update_modules(tree_unflatten(replacements))
    return model
253
+
254
+
255
+ # def remove_lora_layers(model: nn.Module) -> nn.Module:
256
+ # """
257
+ # Remove the LoRA layers from the model.
258
+
259
+ # Args:
260
+ # model (nn.Module): The model with LoRA layers.
261
+
262
+ # Returns:
263
+ # nn.Module: The model without LoRA layers.
264
+ # """
265
+ # reset_layers = []
266
+ # for name, module in model.named_modules():
267
+ # if isinstance(module, LoRALinear):
268
+ # reset_layers.append((name, module.linear))
269
+ # if len(reset_layers) > 0:
270
+ # model.update_modules(tree_unflatten(reset_layers))
271
+ # return model
272
+
273
+
274
def nparams(module):
    """
    Count the parameters held by ``module``.

    A module exposing ``bits`` is treated as quantized: its count is
    ``weight.size * 32 // bits`` (the 32 presumably reflects weights packed into
    32-bit words) plus ``bias.size`` when a bias is present. Any other module
    is counted by summing the sizes of all its parameters.
    """
    if not hasattr(module, "bits"):
        return sum(leaf.size for _, leaf in tree_flatten(module.parameters()))
    bias_count = module.bias.size if hasattr(module, "bias") else 0
    return bias_count + module.weight.size * 32 // module.bits
279
+
280
+
281
def print_trainable_parameters(model):
    """
    Print the share of trainable parameters in ``model``.

    The total is summed over the leaf modules with ``nparams`` (so quantized
    layers count their unpacked size); the trainable count is the total size
    of ``model.trainable_parameters()``. Both are reported in millions. A model
    without parameters reports 0% instead of raising ZeroDivisionError.
    """
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    # Guard the percentage against an empty model.
    percent = trainable_p * 100 / total_p if total_p else 0.0
    print(
        f"Trainable parameters: {percent:.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )
File without changes
@@ -0,0 +1,390 @@
1
+ # server.py
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel, Field
4
+ from typing import List, Optional, Any, Union, Dict
5
+ import uvicorn
6
+ import gc
7
+ import mlx.core as mx
8
+ from mlx_raclate.utils.utils import PIPELINES, load
9
+
10
app = FastAPI(
    title="Raclate Inference API",
    version="0.1.0",
    description="API for using Raclate pipelines (ModernBERT, LFM2, Qwen, etc.) on Apple Silicon",
)

# TODO:
# Separate Services: For complete isolation, run each pipeline type as a separate FastAPI service and use a lightweight API gateway to route requests.
# Worker Pool Architecture: Implement a worker pool where each worker specializes in a specific pipeline, and a dispatcher routes requests to the appropriate worker.


# Holds at most one loaded model (plus its tokenizer and cache key);
# get_model() rebinds it whenever a different model/pipeline/config is requested.
model_cache = dict()
22
+
23
def get_model(model_name: str, pipeline_name: str, config_file: Optional[Dict] = None):
    """
    Return the cached entry for (model, pipeline, config), loading it if needed.

    Only one model is kept in memory. When the requested combination differs
    from the cached one, the previous entry is dropped and memory is released
    before loading the new model.

    Args:
        model_name: Model identifier passed to ``load``.
        pipeline_name: Pipeline name passed to ``load``.
        config_file: Optional overrides passed to ``load`` as ``model_config``.
            They are part of the cache key, so the same model with different
            overrides (e.g. late_interaction on/off) is reloaded. A missing or
            empty config maps to the key "default".

    Returns:
        dict: Cache entry with keys ``key``, ``model_name``, ``pipeline``,
        ``model`` and ``tokenizer``.

    Raises:
        HTTPException: 500 when loading fails.
    """
    global model_cache

    if config_file:
        config_key = str(sorted(config_file.items()))
    else:
        config_key = "default"
    requested_key = f"{model_name}_{pipeline_name}_{config_key}"

    if model_cache.get("key", None) == requested_key:
        return model_cache

    if model_cache:
        # Drop the previous model before loading the next one to limit memory use.
        print(f"Unloading previous model: {model_cache.get('model_name')}")
        model_cache = {}
        mx.eval()  # Ensure evaluation of any pending ops
        gc.collect()
        mx.metal.clear_cache()

    print(f"Loading model: {model_name} | Pipeline: {pipeline_name} | Config: {config_file}")

    try:
        loaded_model, loaded_tokenizer = load(
            model_name,
            pipeline=pipeline_name,
            model_config=config_file,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

    model_cache = dict(
        key=requested_key,
        model_name=model_name,
        pipeline=pipeline_name,
        model=loaded_model,
        tokenizer=loaded_tokenizer,
    )
    return model_cache
68
+
69
+ # -----------------------------------------------------------------------------
70
+ # Pydantic Models
71
+ # -----------------------------------------------------------------------------
72
+
73
class PredictionRequest(BaseModel):
    # Body of POST /predict. model, pipeline and text are required; text_pair,
    # reference_text and label_candidates are only read by specific pipelines,
    # and config_file carries model configuration overrides.
    model: str
    pipeline: str
    text: Union[str, List[str]]

    # Optional parameters depending on pipeline
    text_pair: Optional[Union[str, List[str]]] = Field(
        default=None,
        description="Secondary text for sequence classification pairs (e.g. NLI)",
    )
    reference_text: Optional[Union[str, List[str]]] = Field(
        default=None,
        description="Documents/References for similarity search",
    )

    # Configuration
    config_file: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Configuration overrides (e.g., {'use_late_interaction': True})",
    )
    label_candidates: Optional[Union[Dict[str, str], List[str]]] = Field(
        default=None,
        description="For zero-shot-classification only",
    )
85
+
86
+ # -----------------------------------------------------------------------------
87
+ # Endpoints
88
+ # -----------------------------------------------------------------------------
89
+
90
# Pipelines that /predict has a handler for.
_SERVED_PIPELINES = (
    "text-classification",
    "sentence-similarity",
    "sentence-transformers",
    "embeddings",
    "masked-lm",
    "zero-shot-classification",
)
# Upper bound on the number of input texts per request, to protect memory.
_MAX_BATCH_SIZE = 32
# Number of candidate tokens reported for a [MASK] position.
_TOP_K = 5


def _as_list(value):
    """Wrap a single value in a list; return lists unchanged."""
    return value if isinstance(value, list) else [value]


def _tokenize(tokenizer, texts, max_len, text_pairs=None):
    """Tokenize a batch (optionally paired) into padded, truncated MLX tensors."""
    kwargs = dict(return_tensors="mlx", padding=True, truncation=True, max_length=max_len)
    if text_pairs is None:
        return tokenizer._tokenizer(texts, **kwargs)
    return tokenizer._tokenizer(texts, text_pairs, **kwargs)


def _forward(model, inputs):
    """Run the model on tokenized inputs and return its output dict."""
    return model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask", None),
        return_dict=True,
    )


def _label_for(id2label, index):
    """Look up a label name whether id2label is keyed by str or int."""
    if str(index) in id2label:
        return id2label[str(index)]
    return id2label[index]


def _mask_positions(tokenizer, input_ids):
    """Return (first [MASK] position per row, whether each row has a [MASK])."""
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise HTTPException(status_code=400, detail="This model's tokenizer has no mask token")
    is_mask = input_ids == mask_token_id
    # argmax yields the first match, but also 0 when a row has no match at all,
    # so the per-row presence flags are needed to tell the two cases apart.
    return mx.argmax(is_mask, axis=1), mx.any(is_mask, axis=1)


def _top_k(logits, k):
    """Softmax ``logits`` and return [(token_id, prob), ...] for the k most likely ids."""
    probs = mx.softmax(logits)
    top_ids = mx.argsort(probs)[::-1][:k]
    return list(zip(top_ids.tolist(), probs[top_ids].tolist()))


def _text_classification(request, tokenizer, model, texts, max_len):
    """Sentiment / NLI / regression: per-text label scores."""
    text_pairs = None
    if request.text_pair:
        text_pairs = _as_list(request.text_pair)
        if len(text_pairs) != len(texts):
            raise HTTPException(status_code=400, detail="Length of text and text_pair must match")

    outputs = _forward(model, _tokenize(tokenizer, texts, max_len, text_pairs))
    probs_list = outputs["probabilities"].tolist()  # one row of scores per text

    id2label = getattr(model.config, "id2label", None)
    if not id2label:
        # Just return raw scores (e.g. regression or missing config)
        return {"predictions": probs_list}

    batch_results = []
    for row in probs_list:
        # [[label, score], ...] sorted by score, highest first
        pairs = [[_label_for(id2label, j), score] for j, score in enumerate(row)]
        batch_results.append(sorted(pairs, key=lambda pair: pair[1], reverse=True))
    return {"predictions": batch_results}


def _sentence_similarity(request, tokenizer, model, texts, max_len):
    """Similarity matrix between the query texts and the reference texts."""
    if not request.reference_text:
        raise HTTPException(status_code=400, detail="reference_text is required for sentence-similarity")

    refs = _as_list(request.reference_text)
    q_inputs = _tokenize(tokenizer, texts, max_len)
    d_inputs = _tokenize(tokenizer, refs, max_len)

    # The model handles the complexity (Cosine vs MaxSim) internally based on config
    outputs = model(
        input_ids=q_inputs["input_ids"],
        reference_input_ids=d_inputs["input_ids"],
        attention_mask=q_inputs["attention_mask"],
        reference_attention_mask=d_inputs["attention_mask"],
        return_dict=True,
    )
    # Returns matrix: [batch_size, num_references]
    return {"similarities": outputs["similarities"].tolist()}


def _embeddings(tokenizer, model, texts, max_len):
    """Return the model's 'embeddings' output for each text."""
    outputs = _forward(model, _tokenize(tokenizer, texts, max_len))
    return {"embeddings": outputs["embeddings"].tolist()}


def _masked_lm(tokenizer, model, texts, max_len):
    """Top-k token predictions for the first [MASK] of each text (None if absent)."""
    inputs = _tokenize(tokenizer, texts, max_len)
    positions, has_mask = _mask_positions(tokenizer, inputs["input_ids"])
    logits = _forward(model, inputs)["logits"]

    batch_results = []
    for i, (pos, present) in enumerate(zip(positions.tolist(), has_mask.tolist())):
        if not present:
            batch_results.append(None)
            continue
        batch_results.append([
            {"token": tokenizer.decode([idx]), "score": score}
            for idx, score in _top_k(logits[i, pos], _TOP_K)
        ])
    return {"masked_predictions": batch_results}


def _zero_shot_prompt(text, categories):
    """Answer.ai / ModernBERT-Instruct style classification prompt ending in [MASK]."""
    return f"""You will be given a text and categories to classify the text.

{text}

Read the text carefully and select the right category from the list. Only provide the index of the category:
{categories}

ANSWER: [unused0][MASK]
"""


def _zero_shot(request, tokenizer, model, texts, max_len):
    """Zero-shot classification by reading the [MASK] answer of an instruct prompt."""
    candidates = request.label_candidates
    if not candidates:
        raise HTTPException(status_code=400, detail="label_candidates required for zero-shot")

    if isinstance(candidates, dict):
        lines = [f"{i}: {k} ({v})" for i, (k, v) in enumerate(candidates.items())]
    else:
        lines = [f"{i}: {label}" for i, label in enumerate(candidates)]
    categories = "\n".join(lines)
    top_k = min(_TOP_K, len(candidates))

    prompts = [_zero_shot_prompt(text, categories) for text in texts]
    inputs = _tokenize(tokenizer, prompts, max_len)
    positions, has_mask = _mask_positions(tokenizer, inputs["input_ids"])
    logits = _forward(model, inputs)["logits"]

    batch_results = []
    for i, (pos, present) in enumerate(zip(positions.tolist(), has_mask.tolist())):
        if not present:
            # The prompt ends with [MASK]; it is presumably missing because
            # truncation to max_len cut off the end of a long prompt.
            batch_results.append(None)
            continue
        batch_results.append([
            {"label_index": tokenizer.decode([idx]), "score": score}
            for idx, score in _top_k(logits[i, pos], top_k)
        ])
    return {"classification": batch_results}


@app.post("/predict")
async def predict(request: PredictionRequest):
    """
    Main inference endpoint handling multiple pipelines.

    Validates the request, loads (or reuses) the requested model, dispatches
    to the pipeline-specific handler and always releases MLX/Python memory
    afterwards, even when the handler fails.

    Returns 400 for unsupported pipelines, an empty batch or one larger than
    32 texts, and missing pipeline-specific inputs. Items whose [MASK] token
    is absent (masked-lm, or zero-shot prompts truncated to the model's
    maximum length) are returned as null.
    """
    if request.pipeline not in PIPELINES:
        raise HTTPException(status_code=400, detail=f"Pipeline '{request.pipeline}' not supported. Available: {PIPELINES}")
    if request.pipeline not in _SERVED_PIPELINES:
        # Reject before loading a model instead of silently returning {}.
        raise HTTPException(
            status_code=400,
            detail=f"Pipeline '{request.pipeline}' is not served by this endpoint. Available: {list(_SERVED_PIPELINES)}",
        )

    # Standardize inputs to lists
    texts = _as_list(request.text)
    if not texts:
        raise HTTPException(status_code=400, detail="text must not be empty")
    if len(texts) > _MAX_BATCH_SIZE:
        raise HTTPException(status_code=400, detail=f"Batch size should not exceed {_MAX_BATCH_SIZE} to protect memory")

    model_info = get_model(request.model, request.pipeline, request.config_file)
    tokenizer = model_info["tokenizer"]
    model = model_info["model"]
    max_len = getattr(model.config, "max_position_embeddings", 512)

    try:
        pipeline = request.pipeline
        if pipeline == "text-classification":
            return _text_classification(request, tokenizer, model, texts, max_len)
        if pipeline in ("sentence-similarity", "sentence-transformers"):
            return _sentence_similarity(request, tokenizer, model, texts, max_len)
        if pipeline == "embeddings":
            return _embeddings(tokenizer, model, texts, max_len)
        if pipeline == "masked-lm":
            return _masked_lm(tokenizer, model, texts, max_len)
        return _zero_shot(request, tokenizer, model, texts, max_len)
    finally:
        # Clean up after every request, including failed ones.
        mx.metal.clear_cache()
        gc.collect()
311
+
312
@app.get("/status")
async def status():
    # Liveness check that also reports what (if anything) is currently cached.
    cache = model_cache
    return dict(
        status="online",
        loaded_model=cache.get("model_name"),
        loaded_pipeline=cache.get("pipeline"),
        loaded_config_key=cache.get("key"),
    )
320
+
321
@app.post("/unload")
async def unload_model():
    # Drop the cached model (if any), then release Python garbage and MLX's cache.
    global model_cache
    if model_cache:
        unloaded_name = model_cache.get("model_name")
        model_cache = {}
        gc.collect()
        mx.metal.clear_cache()
        return {"message": f"Unloaded {unloaded_name}"}
    return {"message": "No model loaded"}
332
+
333
if __name__ == "__main__":
    # Serve the app by import string with a single worker, so one process
    # owns the single-model cache.
    uvicorn.run(
        "mlx_raclate.utils.server:app",
        workers=1,
        host="0.0.0.0",
        port=8000,
    )
335
+
336
+ ### EXAMPLE
337
+ '''
338
+ curl -X POST "http://localhost:8000/predict" \
339
+ -H "Content-Type: application/json" \
340
+ -d '{
341
+ "text": [
342
+ "The new MacBook Pro with M3 chip delivers exceptional performance and battery life.",
343
+ "I was really disappointed with the customer service at that restaurant.",
344
+ "This movie has beautiful cinematography but the plot is confusing.",
345
+ "The aging of the population is the archetype of an unpleasant truth for mainstream media readers and for voters, which does not encourage anyone to put it on the table. Age pyramids, birth and fertility indicators, and celibacy rates in all developed countries indicate that the situation is worrying. Among these countries, some managed to stay on-track until about 10 years ago but they eventually fell into line."
346
+ ],
347
+ "model": "answerdotai/ModernBERT-Large-Instruct",
348
+ "pipeline": "zero-shot-classification",
349
+ "label_candidates": {
350
+ "artificial intelligence": "The study of computer science that focuses on the creation of intelligent machines that work and react like humans.",
351
+ "physics": "The study of matter, energy, and the fundamental forces of nature.",
352
+ "society" : "The aggregate of people living together in a more or less ordered community.",
353
+ "biology" : "The study of living organisms, divided into many specialized fields that cover their morphology, physiology, anatomy, behavior, origin, and distribution.",
354
+ "environment" : "The surroundings or conditions in which a person, animal, or plant lives or operates.",
355
+ "health" : "The state of being free from illness or injury.",
356
+ "finance" : "The management of large amounts of money, especially by governments or large companies."
357
+ }
358
+ }'
359
+
360
+ '''
361
+
362
+ '''
363
+ curl -X POST "http://localhost:8000/predict" \
364
+ -H "Content-Type: application/json" \
365
+ -d '{
366
+ "model": "NousResearch/Minos-v1",
367
+ "pipeline": "text-classification",
368
+ "text": [
369
+ "I absolutely love this new framework!",
370
+ "The service was terrible and slow."
371
+ ]
372
+ }'
373
+ '''
374
+
375
+ '''
376
+ curl -X POST "http://localhost:8000/predict" \
377
+ -H "Content-Type: application/json" \
378
+ -d '{
379
+ "model": "LiquidAI/LFM2-ColBERT-350M",
380
+ "pipeline": "sentence-similarity",
381
+ "config_file": {
382
+ "use_late_interaction": true
383
+ },
384
+ "text": ["What is liquid AI?"],
385
+ "reference_text": [
386
+ "Liquid AI builds efficient foundation models.",
387
+ "Water is a liquid state of matter."
388
+ ]
389
+ }'
390
+ '''