tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
  49. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/models/common/model_loader.py
@@ -0,0 +1,433 @@
+ import functools
+ import os
+ from typing import Any, Optional
+
+ import jax
+ import torch
+ from flax import nnx
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.ops.mappings import j2t_dtype
+ from transformers import PretrainedConfig
+ from vllm.config import VllmConfig
+ from vllm.utils import supports_kw
+
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+     apply_qwix_on_abstract_model, apply_qwix_quantization,
+     load_random_weights_into_qwix_abstract_model)
+
+ logger = init_logger(__name__)
+
+ _MODEL_REGISTRY = {}
+
+
+ class UnsupportedArchitectureError(ValueError):
+     """Raised when a model architecture is not supported in the registry."""
+     pass
+
+
+ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
+     # NOTE: Use inline imports here, otherwise the normal imports
+     # would cause JAX init failure when using multi hosts with Ray.
+
+     from tpu_inference.models.jax.deepseek_v3 import DeepSeekV3
+     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
+     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
+     from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
+     from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+     from tpu_inference.models.jax.qwen2_5_vl import \
+         Qwen2_5_VLForConditionalGeneration
+     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
+
+     if os.getenv("NEW_MODEL_DESIGN", False):
+         from tpu_inference.experimental.llama3_jax_stashed import \
+             LlamaForCausalLM
+     else:
+         from tpu_inference.models.jax.llama3 import LlamaForCausalLM
+
+     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
+     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
+     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
+     _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
+     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
+     _MODEL_REGISTRY[
+         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
+     _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
+     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
+
+     architectures = getattr(config, "architectures", [])
+     for arch in architectures:
+         if arch in _MODEL_REGISTRY:
+             return _MODEL_REGISTRY[arch]
+     raise UnsupportedArchitectureError(
+         f"Model architectures {architectures} are not supported for now. "
+         f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+
+
67
+ def _get_nnx_model(
68
+ model_class: Any,
69
+ vllm_config: VllmConfig,
70
+ rng: jax.Array,
71
+ mesh: Mesh,
72
+ ) -> nnx.Module:
73
+
74
+ def create_abstract_model() -> nnx.Module:
75
+ """
76
+ Helper class to create an abstract model for `nnx.eval_shape`.
77
+
78
+ Returns:
79
+ An abstract model function.
80
+ """
81
+ return model_class(vllm_config, rng, mesh)
82
+
83
+ @nnx.jit(donate_argnums=(0, ),
84
+ static_argnames=('use_qwix_on_abstract_model', ))
85
+ def create_jit_model(
86
+ model: nnx.Module,
87
+ use_qwix_on_abstract_model: bool = False) -> nnx.Module:
88
+ """
89
+ Create a jit model.
90
+
91
+ Args:
92
+ model: The model to jit.
93
+ use_qwix_on_abstract_model: Whether to apply Qwix on the abstract model.
94
+
95
+ Returns:
96
+ The jitted model.
97
+ """
98
+ state = nnx.state(model)
99
+ nnx.update(model, state)
100
+ if not use_qwix_on_abstract_model:
101
+ # NOTE: if Qwix is not configured, this will be a no-op
102
+ model = apply_qwix_quantization(vllm_config,
103
+ model,
104
+ rng,
105
+ mesh,
106
+ apply_to_abstract_model=False)
107
+ return model
108
+
109
+ if os.getenv("JAX_RANDOM_WEIGHTS", False):
110
+ # Create a sharded model with random inited weights.
111
+ # TODO: currently Qwen2ForCausalLM is using legacy model implementation
112
+ # will merge the random init logic when all model are migrated to new model implementation
113
+
114
+ # Handle the case where we want to load in random weights to a Qwix-quantized model. Here, we
115
+ # need to run an abstract pass for Qwix first and then load in the random weights.
116
+ if apply_qwix_on_abstract_model(vllm_config):
117
+ abstract_model_fn = apply_qwix_quantization(
118
+ vllm_config,
119
+ create_abstract_model,
120
+ rng,
121
+ mesh,
122
+ apply_to_abstract_model=True)
123
+
124
+ model = nnx.eval_shape(abstract_model_fn)
125
+ quantization_config = vllm_config.model_config.hf_config.quantization_config if hasattr(
126
+ vllm_config.model_config.hf_config,
127
+ "quantization_config") else {}
128
+ load_random_weights_into_qwix_abstract_model(
129
+ rng, model, mesh, quantization_config)
130
+ with mesh:
131
+ jit_model = create_jit_model(model,
132
+ use_qwix_on_abstract_model=True)
133
+ return jit_model
+
+         @nnx.jit
+         def create_sharded_model():
+             model = model_class(vllm_config, rng, mesh)
+             state = nnx.state(model)
+             pspecs = nnx.get_partition_spec(state)
+             sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+             nnx.update(model, sharded_state)
+             # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
+             return model
+
+         with mesh:
+             jit_model = create_sharded_model()
+             # In this case, we are applying Qwix quantization to the true, concrete model
+             jit_model = apply_qwix_quantization(vllm_config,
+                                                 jit_model,
+                                                 rng,
+                                                 mesh,
+                                                 apply_to_abstract_model=False)
+             if hasattr(jit_model, 'initialize_cache'):
+                 jit_model.initialize_cache()
+     else:
+         # We first create an abstract model without allocating any weights,
+         # then fill in its weights during load_weights from HF.
+         # This has two advantages over the normal way:
+         # 1. The model weights are only allocated once. Otherwise the normal way
+         #    random-inits the model weights first, then loads the real weights.
+         #    The two-pass weight allocation makes model loading slow.
+         # 2. Model loading won't OOM. Otherwise the normal way holds
+         #    the full model weights after random-init, then duplicates a layer during
+         #    load_weights. This can easily OOM if the layer is very large.
+         abstract_model_fn = create_abstract_model
+         # NOTE: only one of the abstract (this) or concrete Qwix quantization paths should
+         # be taken
+         if should_apply_qwix_on_abstract_model := apply_qwix_on_abstract_model(
+                 vllm_config):
+             # NOTE: if Qwix is not configured, this will return `create_abstract_model` and
+             # thus be a no-op
+             abstract_model_fn = apply_qwix_quantization(
+                 vllm_config,
+                 create_abstract_model,
+                 rng,
+                 mesh,
+                 apply_to_abstract_model=True)
+         model = nnx.eval_shape(abstract_model_fn)
+         # Although the created model can already work, we still need to jit
+         # the model creation again, otherwise the model forward will have
+         # non-trivial overhead in PjitFunction.
+         with mesh:
+             model.load_weights(rng)
+             jit_model = create_jit_model(
+                 model,
+                 use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
+         return jit_model
+
+
+ # TODO(pooyam): We need to refactor this. This is returning a bunch of functions that do not work with all models, and this is not very easy to see from the code.
+ def get_flax_model(
+     vllm_config: VllmConfig,
+     rng: jax.Array,
+     mesh: Mesh,
+     is_draft_model: bool = False,
+ ) -> nnx.Module:
+     if is_draft_model:
+         model_class = _get_model_architecture(
+             vllm_config.speculative_config.draft_model_config.hf_config)
+     else:
+         model_class = _get_model_architecture(
+             vllm_config.model_config.hf_config)
+     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
+     kv_cache_sharding = NamedSharding(mesh, PartitionSpec(None, None, "model"))
+     hidden_states_sharding = NamedSharding(mesh, PartitionSpec(None,
+                                                                None))  # (T, D)
+
+     # For performance considerations, refer to:
+     # https://flax.readthedocs.io/en/latest/guides/performance.html
+     graphdef, state = nnx.split(jit_model)
+
+     @functools.partial(
+         jax.jit,
+         out_shardings=(
+             kv_cache_sharding,
+             hidden_states_sharding,
+             hidden_states_sharding,  # aux hidden states
+         ),
+         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
+         static_argnums=6,  # 6 is layer_name_to_kvcache_index
+     )
+     def run_model(graphdef, state, *args):
+         model = nnx.merge(graphdef, state)
+         return model(*args)
+
+     logits_sharding = NamedSharding(mesh, PartitionSpec(None, "model"))
+
+     @functools.partial(
+         jax.jit,
+         out_shardings=(logits_sharding),
+     )
+     def run_compute_logits(graphdef, state, *args):
+         model = nnx.merge(graphdef, state)
+         hidden_state, *_ = args
+         return model.compute_logits(hidden_state)
+
+     # Multi-modal support only.
+     # This function calculates the image tokens' embeddings with the ViT.
+     def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
+                                       **kwargs):
+         model = nnx.merge(graphdef, state)
+         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+
+     # This function calculates the embeddings of the input texts and then merges them with the image embeddings.
+     @functools.partial(
+         jax.jit,
+         out_shardings=(logits_sharding),
+     )
+     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+         model = nnx.merge(graphdef, state)
+         return model.get_input_embeddings(*args, **kwargs)
+
+     # For models that want to work with EAGLE-3 speculative decoding
+     @functools.partial(
+         jax.jit,
+         out_shardings=(logits_sharding),
+     )
+     def combine_hidden_states(graphdef, state, hidden_states):
+         model = nnx.merge(graphdef, state)
+         return model.combine_hidden_states(hidden_states)
+
+     model_fn = functools.partial(run_model, graphdef)
+     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
+     get_multimodal_embeddings_fn = functools.partial(
+         run_get_multimodal_embeddings, graphdef)
+     get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
+                                                 graphdef)
+     lora_manager, model = None, None
+     combine_hidden_states_fn = functools.partial(combine_hidden_states,
+                                                  graphdef)
+
+     get_mrope_input_positions_fn = None if not hasattr(
+         model_class,
+         "get_mrope_input_positions") else model_class.get_mrope_input_positions
+
+     return model_fn, compute_logits_fn, combine_hidden_states_fn, get_multimodal_embeddings_fn, get_input_embeddings_fn, get_mrope_input_positions_fn, state, lora_manager, model
+
+
+ def get_vllm_model(
+     vllm_config: VllmConfig,
+     rng: jax.Array,
+     mesh: Mesh,
+ ):
+     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
+
+     model = VllmModelWrapper(
+         vllm_config=vllm_config,
+         rng=rng,
+         mesh=mesh,
+     )
+     params, lora_manager = model.load_weights()
+
+     jit_model = model.jit_step_func()
+     compute_logits_fn = model.jit_compute_logits_func()
+     # The model needs to be returned because LoRA weights are neither torch.nn.Parameter nor torch.nn buffers. After we load the LoRA weights and set them on the torch.nn.Module, we can shard them and move them to the TPU.
+     combine_hidden_states_fn = None
+     return jit_model, compute_logits_fn, combine_hidden_states_fn, None, None, None, params, lora_manager, model
+
+
+ def get_model(
+     vllm_config: VllmConfig,
+     rng: jax.Array,
+     mesh: Mesh,
+     is_draft_model: bool = False,
+ ) -> Any:
+     impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
+
+     if impl == "flax_nnx":
+         try:
+             # Try to load the flax model first
+             return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+         except UnsupportedArchitectureError as e:
+             # Convert the error message to a string to check its contents
+             error_msg = str(e)
+
+             logger.warning(f"Flax model failed with: '{error_msg}'. "
+                            "Falling back to vLLM implementation.")
+             # Fall back to the vLLM model and update the dtype accordingly
+             vllm_config.model_config.dtype = j2t_dtype(
+                 vllm_config.model_config.dtype.dtype)
+             return get_vllm_model(vllm_config, rng, mesh)
+     elif impl == "vllm":
+         return get_vllm_model(vllm_config, rng, mesh)
+     else:
+         raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+
+
+ def _validate_model_interface(model: Any) -> None:
+     """Validates that the model class has the required methods and signatures.
+
+     A valid model must have:
+     - An __init__ method that accepts a 'vllm_config' keyword argument.
+     - A __call__ method that accepts 'kv_caches', 'input_ids', and
+       'attention_metadata' keyword arguments.
+
+     Args:
+         model: The model class to validate.
+
+     Raises:
+         TypeError: If the model does not meet the interface requirements.
+     """
+     # Check for __init__ with vllm_config
+     model_init = getattr(model, "__init__", None)
+     if not callable(model_init):
+         raise TypeError(
+             f"Model {model.__name__} must have an __init__ method.")
+
+     if not supports_kw(model_init, "vllm_config"):
+         raise TypeError(
+             f"Model {model.__name__} __init__ method must accept a "
+             "'vllm_config' keyword argument.")
+
+     # Check for __call__ with required arguments
+     model_call = getattr(model, "__call__", None)
+     # A class object is always callable (it produces an instance).
+     # We need to check if the class _explicitly_ defines a __call__ method for its
+     # instance, which is different from `type.__call__`.
+     has_defined_call = False
+     if isinstance(model, type):
+         if any("__call__" in C.__dict__ for C in model.__mro__):
+             has_defined_call = True
+     elif callable(model_call):
+         # For an instance, a simple callable check is sufficient.
+         has_defined_call = True
+
+     if not has_defined_call:
+         raise TypeError(f"Model {model.__name__} must have a __call__ method.")
+
+     required_call_args = ("kv_caches", "input_ids", "attention_metadata")
+     missing_args = tuple(arg for arg in required_call_args
+                          if not supports_kw(model_call, arg))
+
+     if missing_args:
+         raise TypeError(
+             f"Model {model.__name__} __call__ method is missing required "
+             f"keyword arguments: {missing_args}")
+
+
+ def register_model(arch: str, model: Any) -> None:
+     """
+     Registers a model class for a given architecture name.
+
+     This function registers the model with both the tpu_inference registry
+     and the vLLM registry. For vLLM, it creates a compatible wrapper
+     around the JAX model.
+
+     Args:
+         arch: The name of the architecture (e.g., "LlamaForCausalLM").
+         model: The JAX model class to register (e.g., a flax.nnx.Module).
+     """
+     _validate_model_interface(model)
+
+     # Register with tpu_inference registry for the JAX backend
+     _MODEL_REGISTRY[arch] = model
+
+     # Create a vLLM-compatible wrapper for the JAX model class.
+     # This wrapper inherits from the JAX model and torch.nn.Module
+     # to pass vLLM's type checks. It is not meant to be instantiated
+     # or executed by vLLM's PyTorch backend.
+     def unimplemented_forward(
+         self,
+         input_ids: "torch.Tensor",
+         positions: "torch.Tensor",
+         intermediate_tensors: Optional[Any] = None,
+         inputs_embeds: Optional["torch.Tensor"] = None,
+     ) -> None:
+         raise NotImplementedError(
+             "This is a JAX model and does not implement the PyTorch forward method."
+         )
+
+     # We need a custom __init__ that only calls torch.nn.Module's init,
+     # to avoid triggering JAX logic when vLLM inspects the class.
+     def wrapper_init(self, *args, **kwargs):
+         torch.nn.Module.__init__(self)
+
+     # Dynamically create the wrapper class that is a subclass of both the
+     # JAX model and torch.nn.Module.
+     VllmCompatibleModel = type(
+         f"VllmCompatible{model.__name__}",
+         (model, torch.nn.Module),
+         {
+             "__init__": wrapper_init,
+             "forward": unimplemented_forward,
+             # Prevent vLLM from trying to load weights into this dummy class.
+             "load_weights": lambda self, *args, **kwargs: None,
+         })
+
+     # Register the wrapped model with vLLM's registry.
+     from vllm.model_executor.models.registry import ModelRegistry
+     ModelRegistry.register_model(arch, VllmCompatibleModel)
+     logger.info(
+         f"Registered JAX model {arch} with tpu_inference and vLLM registries.")
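
Taken together, `register_model` and `_validate_model_interface` above define the contract a JAX model must satisfy before it can be registered: `__init__` must accept a `vllm_config` keyword argument, and `__call__` must accept `kv_caches`, `input_ids`, and `attention_metadata` keyword arguments. The sketch below is a hypothetical, minimal illustration of that contract; the class `MyTinyModel`, its hidden size, and its return values are invented for the example and are not part of the package. Only the import path (the `model_loader` module added in this release) and the `register_model(arch, model)` signature come from the diff itself.

import jax
import jax.numpy as jnp
from flax import nnx

from tpu_inference.models.common.model_loader import register_model


class MyTinyModel(nnx.Module):
    """Hypothetical model that satisfies _validate_model_interface."""

    def __init__(self, vllm_config, rng: jax.Array, mesh):
        # A real model would build its layers from vllm_config here.
        self.vllm_config = vllm_config
        self.mesh = mesh

    def __call__(self, kv_caches, input_ids, attention_metadata, *args):
        # A real model would attend over kv_caches and return the updated
        # caches plus hidden states; this stub only mirrors the signature.
        hidden = jnp.zeros((input_ids.shape[0], 128), dtype=jnp.bfloat16)
        return kv_caches, hidden, hidden


# Adds the class to _MODEL_REGISTRY and, through the dynamically created
# torch.nn.Module wrapper, to vLLM's ModelRegistry under the same name.
register_model("MyTinyModelForCausalLM", MyTinyModel)

At load time, `get_model` then selects the implementation via the `MODEL_IMPL_TYPE` environment variable, preferring the flax/nnx path and falling back to the vLLM PyTorch wrapper when an architecture is not present in the JAX registry.
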
{tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.11.1rc2
+ Version: 0.11.1rc3
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -21,9 +21,9 @@ Requires-Dist: pytest-mock
  Requires-Dist: absl-py
  Requires-Dist: numpy
  Requires-Dist: google-cloud-storage
- Requires-Dist: jax
- Requires-Dist: jaxlib
- Requires-Dist: libtpu
+ Requires-Dist: jax==0.7.2
+ Requires-Dist: jaxlib==0.7.2
+ Requires-Dist: libtpu==0.0.23
  Requires-Dist: jaxtyping
  Requires-Dist: flax==0.11.1
  Requires-Dist: torchax==0.0.7
@@ -47,7 +47,7 @@ Dynamic: requires-python
  </p>

  <p align="center">
- | <a href="https://github.com/vllm-project/tpu-inference/tree/main/docs"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh"><b>Developer Slack</b></a> |
+ | <a href="https://tpu.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh"><b>Developer Slack</b></a> |
  </p>

  ---
@@ -90,7 +90,7 @@ Visit our [documentation](https://github.com/vllm-project/tpu-inference/tree/mai

  ## Contribute

- We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing to this effort, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. We recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) if it's your first time contributing.
+ We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing to this effort, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. We recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue+state%3Aopen+label%3A%22good+first+issue%22) if it's your first time contributing.

  ## Contact us

{tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD
@@ -52,6 +52,10 @@ tpu_inference/interfaces/platform.py,sha256=_EVTdilqpXJX2rRdypANuojOhDO0BCkUweka
  tpu_inference/interfaces/request.py,sha256=DRkjdWo5wmkVwQlq9DqpMDPeVPmQd6dfyhN2_k8tezw,950
  tpu_inference/interfaces/scheduler.py,sha256=cFBRkqVNXHrn-08Zvr9B23YTJUzSehy1rE-Fy2V5nvg,816
  tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/collectives/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/collectives/all_gather_matmul.py,sha256=0OYLLjlDmkRYScl7lHRi0o___5I5iMiW1gso-_dWSbc,27255
+ tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py,sha256=KdaOIzTfIgUR0CcUTA46tpYH-cxPNoJx2cTMEvHx-Ac,1399
+ tpu_inference/kernels/collectives/util.py,sha256=LbLD6lOxuszbUsykF89gWQqEJUICCZsfzam3EJDPnFE,1859
  tpu_inference/kernels/flash_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
  tpu_inference/kernels/quantized_matmul/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -67,6 +71,45 @@ tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-
  tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=zc-re4Knsdcfvt2oRO5KGD9-dJs0P8GVJ3yGtclHU2A,54740
  tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=KR2UFpCWjsXCmfMcxxV3yV2DVJp5xcEomOtOKYnSL78,131402
  tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=5ij66Rl7YsjTCH1UERP1W-XXC57sL6ZVPQdTLhMtKHQ,1010
+ tpu_inference/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
+ tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/attention_interface.py,sha256=bXBD8C8RTYTyLJOIGcKd1jH_ZruM0jabLj4n98RIKSA,12003
+ tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
+ tpu_inference/layers/jax/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
+ tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
+ tpu_inference/layers/jax/layers.py,sha256=yv_lC2tbJuzVL-OaXYooX82Ys8hWZATeH9M78coJ3VI,10633
+ tpu_inference/layers/jax/misc.py,sha256=znKv1Nuq_LgYpaIu0qlzUVDgQWnjjG7aqPJGM8kuwcw,566
+ tpu_inference/layers/jax/rope.py,sha256=3ZyR06vwliipkynHHrvcK-Q_aRhvQKDYBOqBYr3oWM8,7029
+ tpu_inference/layers/jax/rope_interface.py,sha256=X0SruXizlCHGnssFujC1pL07UC4Vsp7-gdBy_Q7JZhI,8375
+ tpu_inference/layers/jax/sharding.py,sha256=L0Uh92oLaXFNNQ0qqzNtBD3x3wnTRexQt8GzsCvqH1k,17874
+ tpu_inference/layers/jax/transformer_block.py,sha256=MBN4_hYCGq_-eyomGVUqplBZugZ2LBWUFOgM1UtUxFY,2952
+ tpu_inference/layers/jax/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/attention/attention.py,sha256=KsGuQpOu7yUpimIr5XBniHKaa2ohx_Ke2YaCOvAG3jc,9837
+ tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=YlagoBMwINv2KRH1dr4oEcH_cQ9QMPB55nO2FQZsWs0,14010
+ tpu_inference/layers/jax/attention/llama4_attention.py,sha256=VvUmfBxQEbHf3F2BrcYDUnq5abj7CSDYeRsNx_eVAh0,6162
+ tpu_inference/layers/jax/moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=Q6CuwwiZtWYm6iUee1wJoDJrwJE6_bcznTK2HrtXb0M,26089
+ tpu_inference/layers/jax/moe/moe.py,sha256=cA8R1rjbBwNEoNlsPWjeIBB9nvaRDwlEdwQTVg6lTpY,8762
+ tpu_inference/layers/jax/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=IRfVWjkbVXp9Sv1YrGMMh-LYx1AwbY-3FTXEO1-Ue9g,20423
+ tpu_inference/layers/jax/sample/sampling.py,sha256=-47SC7AqU4UgyO91zAdYXTgrBfdlQ9I89HFZKwU0eQA,3223
+ tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=c3jHNjh1hkFJ5gxGTEk0qBOZnICeY3EELViF5Omp_Nc,2252
+ tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/attention.py,sha256=UVhuNCCrz6jdLNotjGtgaR4CVZ4zNmq5VhsiuOTi6_I,6649
+ tpu_inference/layers/vllm/fused_moe.py,sha256=ld_-sIHRdUY2tTTHrzHzCahVxH4P0sZVZrxYBQYSJhE,17455
+ tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
+ tpu_inference/layers/vllm/sharding.py,sha256=Ck2OzNiucHtrEutDqPQNteu8MEm6isIkE8U5ziowHgM,5779
+ tpu_inference/layers/vllm/quantization/__init__.py,sha256=UGv9cJftrBNoC0pU8SLnTLq3zvqMcolN5YJ6n_J5jf4,1392
+ tpu_inference/layers/vllm/quantization/awq.py,sha256=78H4AYgbvLCrW-5bGbn9_WM1J8KnRzVOInfKSW_QmzQ,8476
+ tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
+ tpu_inference/layers/vllm/quantization/unquantized.py,sha256=QIN6lWfVhN4ikUQlDbD8GhkZcLp1-s1Zi66aqKenmeo,10062
+ tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=ifC6UsCY0tB6BO7X-PWtw-ikUc5IhcPcLvo0_RFrEsM,5253
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=6sQvsxiWdi5Vte8V9vrQ2abaqGqWpq-mtzU7lGAo-ac,8759
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
  tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
  tpu_inference/lora/torch_punica_tpu.py,sha256=ZfwWpPhkz4VQyxX9KeClx1hhchglsCWl0xpcGZsuMG0,12522
@@ -76,6 +119,8 @@ tpu_inference/mock/vllm_envs.py,sha256=hHtbFOM45T5EB2tEGecMGbJA0qOI9dmNYcjANgtah
  tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
  tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
  tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/common/model_loader.py,sha256=kOwc5Dyn433U0F-qZU1D0_k5USkMTY5Em0_WvQfjIYc,17661
  tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  tpu_inference/models/jax/deepseek_v3.py,sha256=735PSgqxxrYL9JIsohhUXimjSNYMeNlepfRLrYHZ9us,40038
  tpu_inference/models/jax/llama3.py,sha256=bi-wIgZxR9h_DwoYHczPZXqrcvbzCVwnANuKnak6HcI,13024
@@ -116,8 +161,8 @@ tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
  tpu_inference/worker/_temporary_vllm_compat.py,sha256=GpF8TuPMDbc0fvIxe7XWEe69FES_F-jJnmcaTgf2dO8,5182
  tpu_inference/worker/base.py,sha256=0Dd3CKk3e7DgvzhfH4M-9-MEQNyYh4zUWSO4tnHFd6s,3140
  tpu_inference/worker/tpu_worker_jax.py,sha256=7b2QVTSbveifm9_BgNnVGwEvh5zPrEi1qiXXTwFFODc,14093
- tpu_inference-0.11.1rc2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- tpu_inference-0.11.1rc2.dist-info/METADATA,sha256=IY3K7w2UtJtLo8TmAYWCjzLpQvEGsMVkshtWhDGiBqs,5423
- tpu_inference-0.11.1rc2.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
- tpu_inference-0.11.1rc2.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
- tpu_inference-0.11.1rc2.dist-info/RECORD,,
+ tpu_inference-0.11.1rc3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ tpu_inference-0.11.1rc3.dist-info/METADATA,sha256=eqqOrhIU_Ub_SoR33_ZJWJzszDrpuLZoi0iTBhcpfUo,5396
+ tpu_inference-0.11.1rc3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tpu_inference-0.11.1rc3.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+ tpu_inference-0.11.1rc3.dist-info/RECORD,,
{tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.1)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any