nexaai 1.0.19rc6__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of nexaai has been flagged as potentially problematic.

Files changed (35)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  5. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  6. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  8. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  9. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  10. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  12. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  13. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  14. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  15. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  16. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
  17. nexaai/mlx_backend/vlm/interface.py +21 -4
  18. nexaai/mlx_backend/vlm/main.py +6 -2
  19. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  20. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  21. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  22. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  23. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  24. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  25. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  26. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  27. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  28. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  29. nexaai/utils/manifest_utils.py +222 -15
  30. nexaai/utils/model_manager.py +83 -7
  31. nexaai/utils/model_types.py +2 -0
  32. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  33. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +35 -24
  34. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  35. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py
@@ -0,0 +1,210 @@
+import math
+import mlx.core as mx
+import mlx.nn as nn
+
+def _gather_sort(x, indices):
+    *_, M = indices.shape
+    indices = indices.flatten()
+    order = mx.argsort(indices)
+    inv_order = mx.argsort(order)
+    return x.flatten(0, -3)[order // M], indices[order], inv_order
+
+
+def _scatter_unsort(x, inv_order, shape=None):
+    x = x[inv_order]
+    if shape is not None:
+        x = mx.unflatten(x, 0, shape)
+    return x
+
+
+class QuantizedSwitchLinear(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        num_experts: int,
+        bias: bool = True,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        scale = math.sqrt(1 / input_dims)
+        self.weight, self.scales, self.biases = mx.quantize(
+            mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(num_experts, output_dims, input_dims),
+            ),
+            group_size=group_size,
+            bits=bits,
+        )
+
+        if bias:
+            self.bias = mx.zeros((num_experts, output_dims))
+
+        self.group_size = group_size
+        self.bits = bits
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze so that we unfreeze any layers we might contain but
+        our parameters will remain frozen."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(recurse=False)
+
+    @property
+    def input_dims(self):
+        return self.scales.shape[2] * self.group_size
+
+    @property
+    def output_dims(self):
+        return self.weight.shape[1]
+
+    @property
+    def num_experts(self):
+        return self.weight.shape[0]
+
+    def __call__(self, x, indices, sorted_indices=False):
+        x = mx.gather_qmm(
+            x,
+            self["weight"],
+            self["scales"],
+            self["biases"],
+            rhs_indices=indices,
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+            sorted_indices=sorted_indices,
+        )
+        if "bias" in self:
+            x = x + mx.expand_dims(self["bias"][indices], -2)
+        return x
+
+
+class SwitchLinear(nn.Module):
+    def __init__(
+        self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
+    ):
+        super().__init__()
+        scale = math.sqrt(1 / input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_experts, output_dims, input_dims),
+        )
+
+        if bias:
+            self.bias = mx.zeros((num_experts, output_dims))
+
+    @property
+    def input_dims(self):
+        return self.weight.shape[2]
+
+    @property
+    def output_dims(self):
+        return self.weight.shape[1]
+
+    @property
+    def num_experts(self):
+        return self.weight.shape[0]
+
+    def __call__(self, x, indices, sorted_indices=False):
+        x = mx.gather_mm(
+            x,
+            self["weight"].swapaxes(-1, -2),
+            lhs_indices=None,
+            rhs_indices=indices,
+        )
+        if "bias" in self:
+            x = x + mx.expand_dims(self["bias"][indices], -2)
+        return x
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        num_experts, output_dims, input_dims = self.weight.shape
+        ql = QuantizedSwitchLinear(
+            input_dims, output_dims, num_experts, False, group_size, bits
+        )
+        ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
+        if "bias" in self:
+            ql.bias = self.bias
+        return ql
+
+
+class SwitchGLU(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        hidden_dims: int,
+        num_experts: int,
+        activation=nn.SiLU(),
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
+        self.activation = activation
+
+    def __call__(self, x, indices) -> mx.array:
+        x = mx.expand_dims(x, (-2, -3))
+
+        # When we have many tokens, then sort them to make sure that the access
+        # of different experts is in order.
+        do_sort = indices.size >= 64
+        idx = indices
+        inv_order = None
+        if do_sort:
+            x, idx, inv_order = _gather_sort(x, indices)
+
+        x_up = self.up_proj(x, idx, sorted_indices=do_sort)
+        x_gate = self.gate_proj(x, idx, sorted_indices=do_sort)
+        x = self.down_proj(
+            self.activation(x_gate) * x_up,
+            idx,
+            sorted_indices=do_sort,
+        )
+
+        if do_sort:
+            x = _scatter_unsort(x, inv_order, indices.shape)
+
+        return x.squeeze(-2)
+
+
+class SwitchMLP(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        hidden_dims: int,
+        num_experts: int,
+        activation=nn.GELU(approx="precise"),
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
+        self.activation = activation
+
+    def __call__(self, x, indices) -> mx.array:
+        x = mx.expand_dims(x, (-2, -3))
+
+        # When we have many tokens, then sort them to make sure that the access
+        # of different experts is in order.
+        do_sort = indices.size >= 64
+        idx = indices
+        inv_order = None
+        if do_sort:
+            x, idx, inv_order = _gather_sort(x, indices)
+
+        x = self.fc1(x, idx, sorted_indices=do_sort)
+        x = self.activation(x)
+        x = self.fc2(x, idx, sorted_indices=do_sort)
+
+        if do_sort:
+            x = _scatter_unsort(x, inv_order, indices.shape)
+
+        return x.squeeze(-2)
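
The `_gather_sort` / `_scatter_unsort` helpers above sort flattened tokens by expert index so that `mx.gather_mm` and `mx.gather_qmm` visit each expert's weights contiguously, then restore the original token order afterwards; sorting is only attempted once `indices.size >= 64`, below which the bookkeeping would cost more than it saves. A minimal sketch of driving `SwitchGLU` from a top-k router (the router, shapes, and `K` below are illustrative assumptions, not code from this wheel):

import mlx.core as mx
from nexaai.mlx_backend.vlm.modeling.models.qwen3vl_moe.switch_layers import SwitchGLU

B, T, D, H, E, K = 2, 65, 512, 1024, 8, 2  # batch, tokens, model dim, hidden dim, experts, top-k

glu = SwitchGLU(input_dims=D, hidden_dims=H, num_experts=E)

x = mx.random.normal((B, T, D))
router_logits = mx.random.normal((B, T, E))  # stand-in for a learned router
# Top-K expert ids per token via argpartition on the negated logits
indices = mx.argpartition(-router_logits, kth=K - 1, axis=-1)[..., :K]

y = glu(x, indices)  # indices.size == B*T*K == 260 >= 64, so the sorted path is taken
print(y.shape)       # (2, 65, 2, 512): one output per selected expert

The output carries one activation per selected expert; a real MoE block would weight these along the `K` axis by the router's scores and sum.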
--- a/nexaai/utils/manifest_utils.py
+++ b/nexaai/utils/manifest_utils.py
@@ -22,6 +22,11 @@ from .model_types import (
     MODEL_TYPE_TO_PIPELINE
 )
 
+MODEL_FILE_TYPE_TO_PLUGIN_ID_MAPPING = {
+    'npu': 'npu',
+    'mlx': 'mlx',
+    'gguf': 'llama_cpp'
+}
 
 def process_manifest_metadata(manifest: Dict[str, Any], repo_id: str) -> Dict[str, Any]:
     """Process manifest metadata to handle null/missing fields."""
@@ -94,12 +99,20 @@ def save_download_metadata(directory_path: str, metadata: Dict[str, Any]) -> None:
         pass
 
 
-def create_gguf_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
+def _get_plugin_id_from_model_file_type(model_file_type: Optional[str], default: str = "llama_cpp") -> str:
+    """Map model file type to PluginId."""
+    return MODEL_FILE_TYPE_TO_PLUGIN_ID_MAPPING.get(model_file_type, default)
+
+
+def create_gguf_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> Dict[str, Any]:
     """Create GGUF format manifest."""
 
     # Load existing manifest to merge GGUF files if it exists
     existing_manifest = load_nexa_manifest(directory_path)
 
+    # Check if there's a downloaded nexa.manifest from the repo
+    downloaded_manifest = old_metadata.get('downloaded_manifest', {})
+
     model_files = {}
     if existing_manifest and "ModelFile" in existing_manifest:
         model_files = existing_manifest["ModelFile"].copy()
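
The new `_get_plugin_id_from_model_file_type` helper is a thin wrapper over the mapping added at the top of the module. Illustrative calls (behavior follows directly from the two hunks above):

# Assuming the mapping and helper defined in the hunks above:
_get_plugin_id_from_model_file_type('gguf')                    # -> 'llama_cpp'
_get_plugin_id_from_model_file_type('mlx')                     # -> 'mlx'
_get_plugin_id_from_model_file_type('npu')                     # -> 'npu'
_get_plugin_id_from_model_file_type(None)                      # -> 'llama_cpp' (default)
_get_plugin_id_from_model_file_type('unknown', default='mlx')  # -> 'mlx'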
@@ -151,10 +164,41 @@ def create_gguf_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
             "Size": file_size
         }
 
+    # Determine PluginId with priority: kwargs > downloaded_manifest > model_file_type > default
+    plugin_id = kwargs.get('plugin_id')
+    if not plugin_id:
+        model_file_type = old_metadata.get('model_file_type')
+        if downloaded_manifest.get('PluginId'):
+            plugin_id = downloaded_manifest.get('PluginId')
+        elif model_file_type:
+            plugin_id = _get_plugin_id_from_model_file_type(model_file_type)
+        else:
+            plugin_id = "llama_cpp"
+
+    # Determine ModelType with priority: kwargs > downloaded_manifest > pipeline_tag mapping
+    model_type = kwargs.get('model_type')
+    if not model_type:
+        if downloaded_manifest.get('ModelType'):
+            model_type = downloaded_manifest.get('ModelType')
+        else:
+            model_type = PIPELINE_TO_MODEL_TYPE.get(old_metadata.get('pipeline_tag'), "other")
+
+    # Determine ModelName with priority: kwargs > downloaded_manifest > empty string
+    model_name = kwargs.get('model_name')
+    if not model_name:
+        model_name = downloaded_manifest.get('ModelName', '')
+
+    # Get DeviceId and MinSDKVersion from kwargs or default to empty string
+    device_id = kwargs.get('device_id', '')
+    min_sdk_version = kwargs.get('min_sdk_version', '')
+
     manifest = {
         "Name": repo_id,
-        "ModelType": PIPELINE_TO_MODEL_TYPE.get(old_metadata.get('pipeline_tag'), "other"),
-        "PluginId": "llama_cpp",
+        "ModelName": model_name,
+        "ModelType": model_type,
+        "PluginId": plugin_id,
+        "DeviceId": device_id,
+        "MinSDKVersion": min_sdk_version,
         "ModelFile": model_files,
         "MMProjFile": mmproj_file,
         "TokenizerFile": {
@@ -172,12 +216,15 @@ def create_gguf_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
     return manifest
 
 
-def create_mlx_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
+def create_mlx_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> Dict[str, Any]:
     """Create MLX format manifest."""
 
     # Load existing manifest to merge MLX files if it exists
     existing_manifest = load_nexa_manifest(directory_path)
 
+    # Check if there's a downloaded nexa.manifest from the repo
+    downloaded_manifest = old_metadata.get('downloaded_manifest', {})
+
     model_files = {}
     extra_files = []
 
@@ -233,10 +280,153 @@ def create_mlx_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
             "Size": file_size
         })
 
+    # Determine PluginId with priority: kwargs > downloaded_manifest > model_file_type > default
+    plugin_id = kwargs.get('plugin_id')
+    if not plugin_id:
+        model_file_type = old_metadata.get('model_file_type')
+        if downloaded_manifest.get('PluginId'):
+            plugin_id = downloaded_manifest.get('PluginId')
+        elif model_file_type:
+            plugin_id = _get_plugin_id_from_model_file_type(model_file_type)
+        else:
+            plugin_id = "mlx"
+
+    # Determine ModelType with priority: kwargs > downloaded_manifest > pipeline_tag mapping
+    model_type = kwargs.get('model_type')
+    if not model_type:
+        if downloaded_manifest.get('ModelType'):
+            model_type = downloaded_manifest.get('ModelType')
+        else:
+            model_type = PIPELINE_TO_MODEL_TYPE.get(old_metadata.get('pipeline_tag'), "other")
+
+    # Determine ModelName with priority: kwargs > downloaded_manifest > empty string
+    model_name = kwargs.get('model_name')
+    if not model_name:
+        model_name = downloaded_manifest.get('ModelName', '')
+
+    # Get DeviceId and MinSDKVersion from kwargs or default to empty string
+    device_id = kwargs.get('device_id', '')
+    min_sdk_version = kwargs.get('min_sdk_version', '')
+
+    manifest = {
+        "Name": repo_id,
+        "ModelName": model_name,
+        "ModelType": model_type,
+        "PluginId": plugin_id,
+        "DeviceId": device_id,
+        "MinSDKVersion": min_sdk_version,
+        "ModelFile": model_files,
+        "MMProjFile": mmproj_file,
+        "TokenizerFile": {
+            "Name": "",
+            "Downloaded": False,
+            "Size": 0
+        },
+        "ExtraFiles": extra_files if extra_files else None,
+        # Preserve old metadata fields
+        "pipeline_tag": old_metadata.get('pipeline_tag') if old_metadata.get('pipeline_tag') else existing_manifest.get('pipeline_tag'),
+        "download_time": old_metadata.get('download_time') if old_metadata.get('download_time') else existing_manifest.get('download_time'),
+        "avatar_url": old_metadata.get('avatar_url') if old_metadata.get('avatar_url') else existing_manifest.get('avatar_url')
+    }
+
+    return manifest
+
+
+def create_npu_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> Dict[str, Any]:
+    """Create NPU format manifest."""
+
+    # Load existing manifest to merge NPU files if it exists
+    existing_manifest = load_nexa_manifest(directory_path)
+
+    # Check if there's a downloaded nexa.manifest from the repo
+    downloaded_manifest = old_metadata.get('downloaded_manifest', {})
+
+    model_files = {}
+    extra_files = []
+
+    # Initialize MMProjFile
+    mmproj_file = {
+        "Name": "",
+        "Downloaded": False,
+        "Size": 0
+    }
+
+    for current_file_name in files:
+        file_path = os.path.join(directory_path, current_file_name)
+        file_size = 0
+        if os.path.exists(file_path):
+            try:
+                file_size = os.path.getsize(file_path)
+            except (OSError, IOError):
+                pass
+
+        # Check if this file is an mmproj file
+        is_current_mmproj = 'mmproj' in current_file_name.lower()
+
+        # If we're downloading specific files and this is marked as mmproj, respect that
+        if is_mmproj and file_name is not None:
+            filenames_to_check = file_name if isinstance(file_name, list) else [file_name]
+            is_current_mmproj = current_file_name in filenames_to_check
+
+        if is_current_mmproj:
+            # This is an mmproj file, put it in MMProjFile
+            mmproj_file = {
+                "Name": current_file_name,
+                "Downloaded": True,
+                "Size": file_size
+            }
+        else:
+            # For NPU, all non-mmproj files go to extra_files
+            extra_files.append({
+                "Name": current_file_name,
+                "Downloaded": True,
+                "Size": file_size
+            })
+
+    # Pick the first file from extra_files and add it to ModelFile with key "N/A"
+    if extra_files:
+        first_file = extra_files[0]
+        model_files["N/A"] = {
+            "Name": first_file["Name"],
+            "Downloaded": first_file["Downloaded"],
+            "Size": first_file["Size"]
+        }
+
+    # Determine PluginId with priority: kwargs > downloaded_manifest > model_file_type > default
+    plugin_id = kwargs.get('plugin_id')
+    if not plugin_id:
+        model_file_type = old_metadata.get('model_file_type')
+        if downloaded_manifest.get('PluginId'):
+            plugin_id = downloaded_manifest.get('PluginId')
+        elif model_file_type:
+            plugin_id = _get_plugin_id_from_model_file_type(model_file_type)
+        else:
+            plugin_id = "npu"
+
+    # Determine ModelType with priority: kwargs > downloaded_manifest > pipeline_tag mapping
+    model_type = kwargs.get('model_type')
+    if not model_type:
+        if downloaded_manifest.get('ModelType'):
+            model_type = downloaded_manifest.get('ModelType')
+        else:
+            model_type = PIPELINE_TO_MODEL_TYPE.get(old_metadata.get('pipeline_tag'), "other")
+
+    # Determine ModelName with priority: kwargs > downloaded_manifest > empty string
+    model_name = kwargs.get('model_name')
+    if not model_name:
+        model_name = downloaded_manifest.get('ModelName', '')
+
+    # Get DeviceId and MinSDKVersion from kwargs or default to empty string
+    device_id = kwargs.get('device_id', '')
+    min_sdk_version = kwargs.get('min_sdk_version', '')
+
     manifest = {
         "Name": repo_id,
-        "ModelType": PIPELINE_TO_MODEL_TYPE.get(old_metadata.get('pipeline_tag'), "other"),
-        "PluginId": "mlx",
+        "ModelName": model_name,
+        "ModelType": model_type,
+        "PluginId": plugin_id,
+        "DeviceId": device_id,
+        "MinSDKVersion": min_sdk_version,
         "ModelFile": model_files,
         "MMProjFile": mmproj_file,
         "TokenizerFile": {
@@ -254,8 +444,21 @@ def create_mlx_manifest(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
     return manifest
 
 
-def detect_model_type(files: List[str]) -> str:
-    """Detect if this is a GGUF or MLX model based on file extensions."""
+def detect_model_type(files: List[str], old_metadata: Dict[str, Any] = None) -> str:
+    """Detect if this is a GGUF, MLX, or NPU model based on file extensions and metadata.
+
+    Args:
+        files: List of files in the model directory
+        old_metadata: Metadata dict that may contain 'model_file_type'
+
+    Returns:
+        Model type string: 'gguf', 'mlx', or 'npu'
+    """
+    # Check if model_file_type is explicitly set to NPU
+    if old_metadata and old_metadata.get('model_file_type') == 'npu':
+        return "npu"
+
+    # Otherwise, detect based on file extensions
     has_gguf = any(f.endswith('.gguf') for f in files)
     has_safetensors = any(f.endswith('.safetensors') or 'safetensors' in f for f in files)
 
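
`detect_model_type` now checks the metadata override before falling back to extension sniffing, which is what lets NPU repos (whose files are neither `.gguf` nor safetensors) dispatch to `create_npu_manifest`. Illustrative calls, assuming the unchanged extension checks in the rest of the function resolve as its docstring describes:

detect_model_type(['model.gguf'])                               # -> 'gguf'
detect_model_type(['model.safetensors', 'config.json'])         # -> 'mlx'
detect_model_type(['weights.bin'], {'model_file_type': 'npu'})  # -> 'npu'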
@@ -268,7 +471,7 @@ def detect_model_type(files: List[str]) -> str:
     return "mlx"
 
 
-def create_manifest_from_files(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
+def create_manifest_from_files(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> Dict[str, Any]:
     """
     Create appropriate manifest format based on detected model type.
 
@@ -276,22 +479,25 @@ def create_manifest_from_files(repo_id: str, files: List[str], directory_path: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]:
         repo_id: Repository ID
         files: List of files in the model directory
         directory_path: Path to the model directory
-        old_metadata: Existing metadata (pipeline_tag, download_time, avatar_url)
+        old_metadata: Existing metadata (pipeline_tag, download_time, avatar_url, model_file_type)
         is_mmproj: Whether the downloaded file is an mmproj file
         file_name: The specific file(s) that were downloaded (None if entire repo was downloaded)
+        **kwargs: Additional metadata including plugin_id, model_name, model_type, device_id, min_sdk_version
 
     Returns:
         Dict containing the appropriate manifest format
     """
-    model_type = detect_model_type(files)
+    model_type = detect_model_type(files, old_metadata)
 
     if model_type == "gguf":
-        return create_gguf_manifest(repo_id, files, directory_path, old_metadata, is_mmproj, file_name)
+        return create_gguf_manifest(repo_id, files, directory_path, old_metadata, is_mmproj, file_name, **kwargs)
+    elif model_type == "npu":
+        return create_npu_manifest(repo_id, files, directory_path, old_metadata, is_mmproj, file_name, **kwargs)
     else: # mlx or other
-        return create_mlx_manifest(repo_id, files, directory_path, old_metadata, is_mmproj, file_name)
+        return create_mlx_manifest(repo_id, files, directory_path, old_metadata, is_mmproj, file_name, **kwargs)
 
 
-def save_manifest_with_files_metadata(repo_id: str, local_dir: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> None:
+def save_manifest_with_files_metadata(repo_id: str, local_dir: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> None:
     """
     Create and save manifest based on files found in the directory.
 
@@ -301,6 +507,7 @@ def save_manifest_with_files_metadata(repo_id: str, local_dir: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> None:
         old_metadata: Existing metadata to preserve
         is_mmproj: Whether the downloaded file is an mmproj file
         file_name: The specific file(s) that were downloaded (None if entire repo was downloaded)
+        **kwargs: Additional metadata including plugin_id, model_name, model_type, device_id, min_sdk_version
     """
     # Get list of files in the directory
     files = []
@@ -314,7 +521,7 @@ def save_manifest_with_files_metadata(repo_id: str, local_dir: str, old_metadata: Dict[str, Any], is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> None:
             pass
 
     # Create appropriate manifest
-    manifest = create_manifest_from_files(repo_id, files, local_dir, old_metadata, is_mmproj, file_name)
+    manifest = create_manifest_from_files(repo_id, files, local_dir, old_metadata, is_mmproj, file_name, **kwargs)
 
     # Save manifest
     save_download_metadata(local_dir, manifest)
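
Downstream callers (presumably including the updated `model_manager.py` in this diff) can now thread registry metadata through to the saved `nexa.manifest` via the new keyword arguments. A hedged usage sketch with invented argument values:

save_manifest_with_files_metadata(
    repo_id='NexaAI/example-model-mlx',  # hypothetical repo id
    local_dir='/models/example-model-mlx',
    old_metadata={'pipeline_tag': 'text-generation', 'model_file_type': 'mlx'},
    plugin_id='mlx',          # kwargs override every fallback in the chain
    model_name='example-model',
    model_type='llm',
    device_id='',
    min_sdk_version='',
)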