xinference 1.6.0.post1__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +1 -1
  3. xinference/conftest.py +0 -7
  4. xinference/core/media_interface.py +9 -8
  5. xinference/core/model.py +13 -6
  6. xinference/core/scheduler.py +1 -10
  7. xinference/core/worker.py +0 -10
  8. xinference/model/audio/model_spec.json +53 -1
  9. xinference/model/audio/model_spec_modelscope.json +57 -1
  10. xinference/model/embedding/core.py +19 -11
  11. xinference/model/image/model_spec.json +10 -1
  12. xinference/model/image/model_spec_modelscope.json +20 -0
  13. xinference/model/llm/__init__.py +6 -54
  14. xinference/model/llm/core.py +19 -5
  15. xinference/model/llm/llama_cpp/core.py +59 -3
  16. xinference/model/llm/llama_cpp/memory.py +455 -0
  17. xinference/model/llm/llm_family.json +185 -397
  18. xinference/model/llm/llm_family.py +88 -16
  19. xinference/model/llm/llm_family_modelscope.json +199 -421
  20. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  21. xinference/model/llm/sglang/core.py +4 -0
  22. xinference/model/llm/transformers/__init__.py +27 -6
  23. xinference/model/llm/transformers/chatglm.py +4 -2
  24. xinference/model/llm/transformers/core.py +49 -28
  25. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  26. xinference/model/llm/transformers/gemma3.py +119 -164
  27. xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
  28. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  29. xinference/model/llm/transformers/multimodal/core.py +205 -0
  30. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  31. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  32. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  33. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  34. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  35. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  36. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  37. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  38. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  39. xinference/model/llm/transformers/opt.py +4 -2
  40. xinference/model/llm/transformers/utils.py +6 -37
  41. xinference/model/llm/vllm/core.py +4 -0
  42. xinference/model/rerank/core.py +7 -1
  43. xinference/model/rerank/utils.py +17 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
  47. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
  52. xinference/web/ui/src/locales/en.json +3 -1
  53. xinference/web/ui/src/locales/zh.json +3 -1
  54. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/METADATA +6 -4
  55. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
  56. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
  57. xinference/model/llm/transformers/cogvlm2.py +0 -442
  58. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  59. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  60. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  61. xinference/model/llm/transformers/intern_vl.py +0 -526
  62. xinference/model/llm/transformers/internlm2.py +0 -94
  63. xinference/model/llm/transformers/minicpmv25.py +0 -193
  64. xinference/model/llm/transformers/omnilmm.py +0 -132
  65. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  66. xinference/model/llm/transformers/qwen_vl.py +0 -360
  67. xinference/thirdparty/omnilmm/LICENSE +0 -201
  68. xinference/thirdparty/omnilmm/__init__.py +0 -0
  69. xinference/thirdparty/omnilmm/chat.py +0 -218
  70. xinference/thirdparty/omnilmm/constants.py +0 -4
  71. xinference/thirdparty/omnilmm/conversation.py +0 -332
  72. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  73. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  74. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  75. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  76. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  77. xinference/thirdparty/omnilmm/utils.py +0 -134
  78. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  79. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  84. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
  85. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
  86. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
  87. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/llama_cpp/core.py

@@ -15,6 +15,7 @@ import concurrent.futures
 import importlib.util
 import logging
 import os
+import pprint
 import queue
 from typing import Iterator, List, Optional, Union
 
@@ -24,6 +25,7 @@ from ....types import ChatCompletion, ChatCompletionChunk, Completion, Completio
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import ChatModelMixin
+from .memory import estimate_gpu_layers
 
 logger = logging.getLogger(__name__)
 
@@ -95,7 +97,12 @@ class XllamaCppModel(LLM, ChatModelMixin):
 
     def load(self):
         try:
-            from xllamacpp import CommonParams, Server
+            from xllamacpp import (
+                CommonParams,
+                Server,
+                get_device_info,
+                ggml_backend_dev_type,
+            )
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -135,6 +142,15 @@ class XllamaCppModel(LLM, ChatModelMixin):
         if os.path.exists(legacy_model_file_path):
             model_path = legacy_model_file_path
 
+        multimodal_projector = self._llamacpp_model_config.get(
+            "multimodal_projector", ""
+        )
+        mmproj = (
+            os.path.join(self.model_path, multimodal_projector)
+            if multimodal_projector
+            else ""
+        )
+
         try:
             params = CommonParams()
             # Compatible with xllamacpp changes
@@ -142,6 +158,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
             params.model = model_path
         except Exception:
             params.model.path = model_path
+        params.mmproj.path = mmproj
         if self.model_family.chat_template:
             params.chat_template = self.model_family.chat_template
         # This is the default value, could be overwritten by _llamacpp_model_config
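The `multimodal_projector` value is read from `_llamacpp_model_config` and joined onto the model directory to form the mmproj path above. A minimal, hypothetical sketch of supplying it from the client follows; it assumes extra keyword arguments to `launch_model` are forwarded into the llama.cpp model config (as options such as `n_ctx` are), and the model and file names are placeholders:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
# "multimodal_projector" is the key XllamaCppModel.load() reads; the value is a
# GGUF projector file expected to sit inside the downloaded model directory.
model_uid = client.launch_model(
    model_name="qwen2.5-vl-instruct",              # placeholder model name
    model_engine="llama.cpp",
    model_format="ggufv2",
    quantization="q4_k_m",                         # placeholder quantization
    multimodal_projector="mmproj-model-f16.gguf",  # placeholder file name
)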
@@ -165,6 +182,41 @@ class XllamaCppModel(LLM, ChatModelMixin):
         # Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
         # 0x7FFFFFFF is INT32 max, will be auto set to all layers
         params.n_gpu_layers = 0x7FFFFFFF
+        try:
+            device_info = get_device_info()
+            gpus = [
+                info
+                for info in device_info
+                if info["type"]
+                == ggml_backend_dev_type.GGML_BACKEND_DEVICE_TYPE_GPU
+            ]
+            if gpus:
+                logger.info(
+                    "Try to estimate num gpu layers, n_ctx: %s, n_batch: %s, n_parallel: %s, gpus:\n%s",
+                    params.n_ctx,
+                    params.n_batch,
+                    params.n_parallel,
+                    pprint.pformat(gpus),
+                )
+                estimate = estimate_gpu_layers(
+                    gpus=gpus,
+                    model_path=model_path,
+                    projectors=[mmproj] if mmproj else [],
+                    context_length=params.n_ctx,
+                    batch_size=params.n_batch,
+                    num_parallel=params.n_parallel,
+                    kv_cache_type="",
+                )
+                logger.info("Estimate num gpu layers: %s", estimate)
+                if estimate.tensor_split:
+                    params.tensor_split = estimate.tensor_split
+                else:
+                    params.n_gpu_layers = estimate.layers
+        except Exception as e:
+            logger.exception(
+                "Estimate num gpu layers for llama.cpp backend failed: %s", e
+            )
+
         self._llm = Server(params)
         self._executor = concurrent.futures.ThreadPoolExecutor(
             max_workers=max(10, n_threads)
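For a sense of the quantities the estimator works with, here is a small arithmetic sketch of the per-layer KV-cache term computed in the new memory.py (the empty kv_cache_type selects an f16 cache, i.e. 2 bytes per element); the model dimensions are made-up illustrative values, not taken from any particular GGUF:

# Illustrative numbers only: 32 blocks, 8 KV heads, 128-dim K/V heads, 8192 context.
context_length = 8192
head_dim_k = head_dim_v = 128      # attention.key_length / attention.value_length
n_head_kv = 8                      # attention.head_count_kv
bytes_per_element = 2              # f16 KV cache

kv_per_layer = context_length * (head_dim_k + head_dim_v) * n_head_kv * bytes_per_element
print(kv_per_layer / 1024**2)      # 32.0 MiB per layer
print(32 * kv_per_layer / 1024**3) # 1.0 GiB across all 32 layers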
@@ -207,11 +259,13 @@ class XllamaCppModel(LLM, ChatModelMixin):
                 q.put(res)
             except Exception as e:
                 logger.exception("handle_completions callback failed: %s", e)
+                q.put(_Error(str(e)))
 
         try:
             self._llm.handle_completions(prompt_json, _error_callback, _ok_callback)
         except Exception as ex:
             logger.exception("handle_completions failed: %s", ex)
+            q.put(_Error(str(ex)))
         q.put(_Done)
 
         assert self._executor
@@ -271,6 +325,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
                 q.put(res)
             except Exception as e:
                 logger.exception("handle_chat_completions callback failed: %s", e)
+                q.put(_Error(str(e)))
 
         try:
             self._llm.handle_chat_completions(
@@ -278,6 +333,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
             )
         except Exception as ex:
             logger.exception("handle_chat_completions failed: %s", ex)
+            q.put(_Error(str(ex)))
         q.put(_Done)
 
         assert self._executor
@@ -288,7 +344,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
         def _to_iterator():
             while (r := q.get()) is not _Done:
                 if type(r) is _Error:
-                    raise Exception("Got error in chat stream: %s", r.msg)
+                    raise Exception(f"Got error in chat stream: {r.msg}")
                 # Get valid keys (O(1) lookup)
                 chunk_keys = ChatCompletionChunk.__annotations__
                 # The chunk may contain additional keys (e.g., system_fingerprint),
@@ -302,5 +358,5 @@ class XllamaCppModel(LLM, ChatModelMixin):
         else:
             r = q.get()
             if type(r) is _Error:
-                raise Exception("Got error in chat: %s", r.msg)
+                raise Exception(f"Got error in chat: {r.msg}")
             return self._to_chat_completion(r, self.reasoning_parser)
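These hunks make the producer push an _Error sentinel onto the result queue whenever the llama.cpp callback or call fails, so the consumer raises instead of silently ending the stream. A small standalone sketch of that queue protocol (names mirror the diff; the surrounding server plumbing is simplified away):

import queue

class _Done:  # sentinel marking end of stream
    pass

class _Error:
    def __init__(self, msg: str):
        self.msg = msg

q: queue.Queue = queue.Queue()

def _produce():
    try:
        for chunk in ("a", "b"):
            q.put(chunk)
        raise RuntimeError("backend failed mid-stream")
    except Exception as e:
        q.put(_Error(str(e)))  # the added lines: surface the failure to the consumer
    q.put(_Done)

_produce()
while (r := q.get()) is not _Done:
    if type(r) is _Error:
        raise Exception(f"Got error in chat stream: {r.msg}")
    print(r)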
xinference/model/llm/llama_cpp/memory.py (new file)

@@ -0,0 +1,455 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any
+
+from gguf import GGUFReader, GGUFValueType  # noqa: E402
+
+logger = logging.getLogger(__name__)
+
+
+def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
+    file_endian = reader.endianess.name  # codespell:ignore
+    if reader.byte_order == "S":
+        host_endian = "BIG" if file_endian == "LITTLE" else "LITTLE"
+    else:
+        host_endian = file_endian
+    return (host_endian, file_endian)
+
+
+def dump_metadata_json(reader: GGUFReader, model_path: str) -> dict:
+    host_endian, file_endian = get_file_host_endian(reader)
+    metadata: dict[str, Any] = {}
+    tensors: dict[str, Any] = {}
+    result = {
+        "filename": model_path,
+        "endian": file_endian,
+        "metadata": metadata,
+        "tensors": tensors,
+    }
+    for idx, field in enumerate(reader.fields.values()):
+        curr: dict[str, Any] = {
+            "index": idx,
+            "type": field.types[0].name if field.types else "UNKNOWN",
+            "offset": field.offset,
+        }
+        metadata[field.name] = curr
+        if field.types[:1] == [GGUFValueType.ARRAY]:
+            curr["array_types"] = [t.name for t in field.types][1:]
+            curr["value"] = field.contents()
+        else:
+            curr["value"] = field.contents()
+    for i, tensor in enumerate(reader.tensors):
+        tensors[tensor.name] = {
+            "index": i,
+            "shape": tensor.shape.tolist(),
+            "type": tensor.tensor_type.name,
+            "offset": tensor.field.offset,
+            "n_bytes": tensor.n_bytes,
+        }
+    return result
+
+
+@dataclass
+class MemoryEstimate:
+    # How many layers we predict we can load
+    layers: int
+    # The size of the graph which occupies the main GPU
+    graph: int
+    # How much VRAM will be allocated given the number of layers we predict
+    vram_size: int
+    # The total size of the model if loaded into VRAM. If all layers are loaded, vram_size == total_size
+    total_size: int
+    # For multi-GPU scenarios, this provides the tensor split parameter
+    tensor_split: str
+    # For multi-GPU scenarios, this is the size in bytes per GPU
+    gpu_sizes: list[int]
+
+
+def _get_max_min(value):
+    if isinstance(value, Sequence):
+        return max(value), min(value)
+    else:
+        return value, value
+
+
+def graph_size(
+    data: dict,
+    context_length: int,
+    batch_size: int,
+    num_parallel: int,
+    kv_cache_type: str,
+):
+    """
+    Most of the logic comes from `GraphSize` in https://github.com/ollama/ollama/blob/main/fs/ggml/ggml.go
+    """
+    if context_length < batch_size:
+        batch_size = context_length
+
+    metadata = data["metadata"]
+    architecture = metadata["general.architecture"]["value"]
+    embedding_length = metadata[f"{architecture}.embedding_length"]["value"]
+    block_count = metadata[f"{architecture}.block_count"]["value"]
+    head_count_max, head_count_min = _get_max_min(
+        metadata[f"{architecture}.attention.head_count"]["value"]
+    )
+    head_count_kv_max, head_count_kv_min = _get_max_min(
+        metadata[f"{architecture}.attention.head_count_kv"]["value"]
+    )
+    vocab = len(metadata["tokenizer.ggml.tokens"]["value"])
+    embedding_head_count_max = (
+        (embedding_length // head_count_min) if head_count_min > 0 else 0
+    )
+    embedding_head_count_k = metadata.get(
+        f"{architecture}.attention.key_length", {}
+    ).get("value", embedding_head_count_max)
+    embedding_head_count_v = metadata.get(
+        f"{architecture}.attention.value_length", {}
+    ).get("value", embedding_head_count_max)
+
+    # f16(default)
+    bytes_per_kv_element = {
+        "q8_0": 1,  # 1/2 of fp16
+        "q4_0": 0.5,  # 1/4 of fp16
+    }.get(kv_cache_type, 2)
+
+    kv = [0] * block_count
+    for i in range(block_count):
+        kv[i] = (
+            context_length
+            * (embedding_head_count_k + embedding_head_count_v)
+            * head_count_kv_max
+            * bytes_per_kv_element
+        )
+
+    full_offload = 0
+    partial_offload = 0
+    if architecture in ["llama", "llama4"]:
+        full_offload = max(
+            4
+            * batch_size
+            * (1 + 4 * embedding_length + context_length * (1 + head_count_max)),
+            4 * batch_size * (embedding_length + vocab),
+        )
+        partial_offload = 4 * batch_size * embedding_length
+        partial_offload += max(
+            4
+            * batch_size
+            * (1 + embedding_length + max(context_length, embedding_length))
+            + embedding_length * embedding_length * 9 / 16
+            + 4
+            * context_length
+            * (
+                batch_size * head_count_max
+                + embedding_head_count_max * head_count_kv_max
+            ),
+            4 * batch_size * (embedding_length + vocab)
+            + embedding_length * vocab * 105 / 128,
+        )
+    elif architecture in ["gemma", "gemma2", "gemma3"]:
+        full_offload = max(
+            4 * batch_size * (embedding_length + vocab),
+            4
+            * batch_size
+            * (
+                2
+                + context_length
+                + context_length * head_count_max
+                + 2 * embedding_length
+                + 2 * embedding_head_count_k * head_count_max
+            ),
+        )
+        partial_offload = max(
+            4 * embedding_length * batch_size
+            + embedding_length * vocab * 105 / 128
+            + 4 * vocab * batch_size,
+            4
+            * batch_size
+            * (
+                2 * embedding_length
+                + 1
+                + 2 * embedding_head_count_k * head_count_max
+                + context_length
+                + context_length * head_count_max
+            )
+            + 4 * embedding_head_count_k * context_length * 8
+            + embedding_length * embedding_head_count_k * head_count_max * 9 / 16,
+        )
+        if architecture == "gemma3":
+            gemma3_global_cache_count = 6
+            sliding_window = (
+                num_parallel
+                * metadata[f"{architecture}.attention.sliding_window"]["value"]
+                + batch_size
+            )
+            for i in range(block_count):
+                if (i + 1) % gemma3_global_cache_count != 0:
+                    kv[i] = (
+                        sliding_window
+                        * (embedding_head_count_k + embedding_head_count_v)
+                        * head_count_kv_max
+                        * bytes_per_kv_element
+                    )
+    elif architecture == "qwen2":
+        full_offload = max(
+            4 * batch_size * (embedding_length + vocab),
+            4
+            * batch_size
+            * (
+                1
+                + 2 * embedding_length
+                + context_length
+                + context_length * head_count_max
+            ),
+        )
+
+        partial_offload = max(
+            4 * batch_size * (embedding_length + vocab)
+            + embedding_length * vocab * 105 / 128,
+            4
+            * (
+                batch_size
+                * (1 + 2 * embedding_length + context_length * (1 + head_count_max))
+                + embedding_length * (1 + context_length)
+            ),
+        )
+    elif architecture == "stablelm":
+        full_offload = (
+            4
+            * batch_size
+            * (context_length * (1 + head_count_max) + 3 * embedding_length + 2)
+        )
+        partial_offload = max(
+            4 * batch_size * (vocab + 2 * embedding_length), full_offload
+        )
+    elif architecture == "deepseek2":
+        full_offload = max(
+            4 * batch_size * (3 * embedding_length + vocab),
+            4
+            * batch_size
+            * (
+                3 * embedding_length
+                + 2
+                + context_length * (1 + head_count_kv_max)
+                + 2 * embedding_head_count_k * head_count_kv_max
+            ),
+        )
+
+        partial_offload = max(
+            4 * batch_size * (3 * embedding_length + vocab)
+            + embedding_length * vocab * 105 / 128,
+            4
+            * batch_size
+            * (
+                2 * embedding_length
+                + 1
+                + 2 * embedding_head_count_k * head_count_kv_max
+                + context_length
+                + context_length * head_count_kv_max
+            )
+            + 4 * embedding_head_count_k * context_length * head_count_kv_max
+            + embedding_length * embedding_head_count_k * head_count_kv_max * 9 / 16,
+        )
+
+    kv_total = sum(kv)
+    if partial_offload == 0:
+        partial_offload = (
+            head_count_max
+            / (1 if head_count_kv_min <= 0 else head_count_kv_min)
+            * kv_total
+            / 6
+        )
+    if full_offload == 0:
+        full_offload = partial_offload
+
+    return kv, partial_offload, full_offload
+
+
+def projector_memory_requirements(projector: str):
+    reader = GGUFReader(projector, "r")
+    data = dump_metadata_json(reader, projector)
+    return sum(t["n_bytes"] for t in data["tensors"].values())
+
+
+def estimate_gpu_layers(
+    gpus: list[dict],
+    model_path: str,
+    projectors: list[str],
+    context_length: int,
+    batch_size: int,
+    num_parallel: int,
+    kv_cache_type: str,
+):
+    """
+    Most of the logic comes from `EstimateGPULayers` in https://github.com/ollama/ollama/blob/main/llm/memory.go
+    """
+    # Projectors loaded into GPU0 only
+    projector_weights = sum(map(projector_memory_requirements, projectors))
+    if projector_weights > 0:
+        # Multimodal models require at least 2048 context
+        context_length = max(context_length, 2048)
+    reader = GGUFReader(model_path, "r")
+    data = dump_metadata_json(reader, model_path)
+    kv, graph_partial_offload, graph_full_offload = graph_size(
+        data,
+        context_length=context_length,
+        batch_size=batch_size,
+        num_parallel=num_parallel,
+        kv_cache_type=kv_cache_type,
+    )
+    # Get all layer sizes
+    metadata = data["metadata"]
+    architecture = metadata["general.architecture"]["value"]
+    block_count = metadata[f"{architecture}.block_count"]["value"]
+    layer_sizes = [0] * block_count
+    for name, layer in data["tensors"].items():
+        if name.startswith("blk."):
+            index = int(name[len("blk.") :].split(".")[0])
+            layer_sizes[index] += layer["n_bytes"]
+    layer_size = layer_sizes[0] if layer_sizes else 0
+
+    if len(kv) > 0:
+        layer_size += kv[0]
+    # On metal there's no partial offload overhead
+    if gpus[0]["name"] == "Metal":
+        graph_partial_offload = graph_full_offload
+    elif len(gpus) > 1:
+        # Multi gpu should always use the partial graph size
+        graph_full_offload = graph_partial_offload
+
+    # Get output layer size
+    memory_layer_output = 0
+    # Output layer handled at the end if we have space
+    for name, layer in data["tensors"].items():
+        if any(
+            name.startswith(prefix)
+            for prefix in ["output_norm", "output", "token_embd"]
+        ):
+            memory_layer_output += layer["n_bytes"]
+
+    # Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
+    default_memory_min = 512 * 1024**2
+    gpu_allocations = [0] * len(gpus)
+    gpus_with_space: list[int] = []
+    for i in range(len(gpus)):
+        gpu0_overhead = projector_weights if len(gpus_with_space) == 0 else 0
+        minimum_memory = gpus[i].get("memory_min", default_memory_min)
+        if (
+            gpus[i]["memory_free"]
+            < gpu0_overhead
+            + max(graph_partial_offload, graph_full_offload)
+            + minimum_memory
+            + 2 * layer_size
+        ):
+            continue
+        gpus_with_space.append(i)
+        gpu_allocations[i] += gpu0_overhead + minimum_memory + layer_size
+
+    overflow = 0
+    if len(gpus_with_space) == 0:
+        overflow = projector_weights
+
+    # For all the layers, find where they can fit on the GPU(s)
+    layer_count = 0
+    layer_counts = [0] * len(gpus)
+    for i in range(block_count - 1, -1, -1):
+        layer_size = layer_sizes[i]
+        layer_size += kv[i]
+
+        # Distribute the layers across the GPU(s) that have space
+        for j in range(len(gpus_with_space), 0, -1):
+            g = gpus_with_space[i % j]
+            used = gpu_allocations[g] + max(graph_partial_offload, graph_full_offload)
+            if gpus[g]["memory_free"] > used + layer_size:
+                gpu_allocations[g] += layer_size
+                layer_counts[g] += 1
+                layer_count += 1
+                break
+            else:
+                gpus_with_space = (
+                    gpus_with_space[: i % j] + gpus_with_space[i % j + 1 :]
+                )
+
+        if len(gpus_with_space) == 0:
+            overflow += layer_size
+
+    fully_loaded = False
+    if layer_count >= block_count:
+        fully_loaded = True
+
+    # Determine if we need to consider output then find where it fits
+    if memory_layer_output > 0:
+        for j in range(len(gpus_with_space), 0, -1):
+            g = gpus_with_space[layer_count % j]
+            used = gpu_allocations[g] + max(graph_partial_offload, graph_full_offload)
+            if gpus[g]["memory_free"] > used + memory_layer_output:
+                gpu_allocations[g] += memory_layer_output
+                layer_counts[g] += 1
+                layer_count += 1
+                break
+            else:
+                gpus_with_space = (
+                    gpus_with_space[: layer_count % j]
+                    + gpus_with_space[layer_count % j + 1 :]
+                )
+
+        if layer_count < block_count + 1:
+            fully_loaded = False
+            overflow += memory_layer_output
+
+    # Add the applicable (full or partial) graph allocations
+    for i in range(len(gpus)):
+        if layer_counts[i] <= 0:
+            continue
+        if fully_loaded:
+            gpu_allocations[i] += graph_full_offload
+        else:
+            gpu_allocations[i] += graph_partial_offload
+
+    if fully_loaded:
+        graph_offload = graph_full_offload
+    else:
+        graph_offload = graph_partial_offload
+
+    # Summaries
+    memory_required_partial = sum(gpu_allocations)
+    memory_required_total = memory_required_partial + overflow
+
+    tensor_split = ""
+    if len(gpus) > 1:
+        tensor_split = ",".join(str(c) for c in layer_counts)
+
+    estimate = MemoryEstimate(
+        layers=0,
+        graph=0,
+        vram_size=0,
+        total_size=int(memory_required_total),
+        tensor_split="",
+        gpu_sizes=[],
+    )
+    if gpus[0]["name"] == "CPU":
+        return estimate
+    if layer_count == 0:
+        return estimate
+
+    estimate.layers = layer_count
+    estimate.graph = int(graph_offload)
+    estimate.vram_size = int(memory_required_partial)
+    estimate.total_size = int(memory_required_total)
+    estimate.tensor_split = tensor_split
+    estimate.gpu_sizes = [int(i) for i in gpu_allocations]
+    return estimate
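To illustrate how estimate_gpu_layers is driven, a small hypothetical sketch follows; the GPU dicts carry the fields the estimator actually reads (name, memory_free, optional memory_min), and the paths and memory sizes are made up:

from xinference.model.llm.llama_cpp.memory import estimate_gpu_layers

# Two hypothetical 24 GiB GPUs; values in bytes, as the estimator expects.
gpus = [
    {"name": "CUDA0", "memory_free": 24 * 1024**3},
    {"name": "CUDA1", "memory_free": 24 * 1024**3},
]
estimate = estimate_gpu_layers(
    gpus=gpus,
    model_path="/path/to/model-q4_k_m.gguf",
    projectors=[],          # e.g. ["/path/to/mmproj.gguf"] for a multimodal model
    context_length=8192,
    batch_size=2048,
    num_parallel=1,
    kv_cache_type="",       # "" selects the f16 default
)
print(estimate.layers, estimate.vram_size, estimate.tensor_split, estimate.gpu_sizes)

With more than one GPU, tensor_split comes back as a comma-separated per-GPU layer count, which core.py assigns to params.tensor_split; with a single GPU it stays empty and estimate.layers is used to cap params.n_gpu_layers, mirroring the branch added in load() above.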