tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tests/test_base.py ADDED
@@ -0,0 +1,201 @@
1
+ import logging
2
+ import unittest
3
+ import warnings
4
+ from dataclasses import dataclass, field, fields
5
+ from typing import Any, List, Mapping
6
+
7
+ from tpu_inference.layers.jax.base import Config
8
+
9
# Import vLLM's ModelConfig quietly: Python warnings are ignored inside the
# `catch_warnings` block, and the "vllm" logger is raised to ERROR while the
# import statement runs.
vllm_logger = logging.getLogger("vllm")
# NOTE(review): the previous level is captured here but never read again; the
# logger is set to WARNING after the import instead of being restored to this
# value -- confirm whether that is intentional.
original_level = vllm_logger.level

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Set the vLLM logger to ERROR to suppress its messages
    vllm_logger.setLevel(logging.ERROR)

    # Import the class; all warnings will be suppressed
    from vllm.config import ModelConfig

# Once the import has run, set the vLLM logger to WARNING.
vllm_logger.setLevel(logging.WARNING)
23
+
24
+
25
def setup_vllm_config(subconfig_types: List[str],
                      overrides: List[Mapping[str, Any]]):
    """Create a `SimpleVllmConfig` and apply attribute overrides to it.

    `subconfig_types` and `overrides` are walked in lockstep with `zip`, so
    entries beyond the length of the shorter list are ignored.

    Args:
        subconfig_types: One entry per override. The value "model" routes the
            paired override to `model_config`; any other value routes it to
            the top-level config object.
        overrides: Mappings of attribute name to value. Every pair is applied
            with `setattr` on the selected target.

    Returns:
        The `SimpleVllmConfig` instance with all overrides applied.
    """
    config = SimpleVllmConfig()
    for kind, attrs in zip(subconfig_types, overrides):
        # Pick the object that receives this override's attributes.
        target = config.model_config if kind == "model" else config
        for name in attrs:
            setattr(target, name, attrs[name])
    return config
36
+
37
+
38
def _default_model_config() -> ModelConfig:
    """Return the `ModelConfig` used when no model config is supplied."""
    # An explicit max_model_len is passed to turn off warnings.
    return ModelConfig(max_model_len=1024)


@dataclass
class SimpleVllmConfig:
    """Small stand-in config object holding only two fields."""
    # Free-form mapping; a fresh empty dict per instance by default.
    additional_config: Mapping[str, Any] = field(default_factory=dict)
    # Built per instance by `_default_model_config`.
    model_config: ModelConfig = field(default_factory=_default_model_config)
44
+
45
+
46
@dataclass
class SimpleConfig(Config):
    """`Config` subclass with one vLLM config field and three plain fields."""
    vllm_config: SimpleVllmConfig
    arg1: str
    arg2: str
    arg3: int

    def is_equal(self, other: Config):
        """Compare this config with `other`, ignoring `vllm_config`.

        Args:
            other: The config whose attributes are read by field name.

        Returns:
            False as soon as one compared field differs, True otherwise.
        """
        compared_names = (f.name for f in fields(self)
                          if f.name != "vllm_config")
        # `any` stops at the first mismatching field.
        return not any(
            getattr(self, name) != getattr(other, name)
            for name in compared_names)
59
+
60
+
61
class ConfigOverrideTests(unittest.TestCase):
    """Compare `SimpleConfig` objects built with and without overrides.

    Every test builds a `SimpleConfig` with arg1="foo", arg2="bar", arg3=123
    on top of a vLLM config carrying overrides, and checks with `is_equal`
    that it matches a `SimpleConfig` built directly from the expected values
    on a default `SimpleVllmConfig`.
    """

    def _check(self, subconfig_types, overrides, **expected_args):
        """Build the overridden and expected configs and assert they match.

        Args:
            subconfig_types: Passed through to `setup_vllm_config`.
            overrides: Passed through to `setup_vllm_config`.
            **expected_args: The arg1/arg2/arg3 values of the expected config.
        """
        overridden = SimpleConfig(
            vllm_config=setup_vllm_config(subconfig_types, overrides),
            arg1="foo",
            arg2="bar",
            arg3=123)
        expected = SimpleConfig(vllm_config=SimpleVllmConfig(),
                                **expected_args)
        self.assertTrue(overridden.is_equal(expected))

    def test_additional_config_overrides(self):
        additional = {"additional_config": {"arg1": "val1", "arg2": "val2"}}
        self._check([''], [additional], arg1="val1", arg2="val2", arg3=123)

    def test_hf_overrides(self):
        hf = {"hf_overrides": {"arg2": "val2", "arg3": 456}}
        self._check(['model'], [hf], arg1="foo", arg2="val2", arg3=456)

    def test_additional_and_hf_overrides(self):
        additional = {"additional_config": {"arg1": "val1", "arg2": "val2"}}
        hf = {"hf_overrides": {"arg2": "val3", "arg3": 456}}
        self._check(['', 'model'], [additional, hf],
                    arg1="val1",
                    arg2="val3",
                    arg3=456)

    def test_additional_and_generate_overrides(self):
        additional = {"additional_config": {"arg1": "val1", "arg2": "val2"}}
        generate = {
            "override_generation_config": {
                "arg2": "val3",
                "arg3": 456
            }
        }
        self._check(['', 'model'], [additional, generate],
                    arg1="val1",
                    arg2="val3",
                    arg3=456)

    def test_hf_and_generate_overrides(self):
        hf = {"hf_overrides": {"arg2": "val2", "arg3": 456}}
        generate = {
            "override_generation_config": {
                "arg2": "val4",
                "arg3": 789
            }
        }
        self._check(['model', 'model'], [hf, generate],
                    arg1="foo",
                    arg2="val4",
                    arg3=789)

    def test_additional_and_hf_and_generate_overrides(self):
        additional = {"additional_config": {"arg1": "val1", "arg2": "val2"}}
        hf = {"hf_overrides": {"arg2": "val2", "arg3": 456}}
        generate = {
            "override_generation_config": {
                "arg1": "val3",
                "arg2": "val4",
                "arg3": 789
            }
        }
        self._check(['', 'model', 'model'], [additional, hf, generate],
                    arg1="val3",
                    arg2="val4",
                    arg3=789)
198
+
199
+
200
# Run the test cases through unittest's runner when executed as a script.
if __name__ == '__main__':
    unittest.main()