tpu-inference 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. See the registry's release page for more details.

Files changed (123)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  53. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  54. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  55. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  56. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  57. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  58. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  59. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  60. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  61. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  63. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  65. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  67. tpu_inference/logger.py +10 -0
  68. tpu_inference/lora/__init__.py +0 -0
  69. tpu_inference/lora/torch_lora_ops.py +103 -0
  70. tpu_inference/lora/torch_punica_tpu.py +308 -0
  71. tpu_inference/mock/__init__.py +0 -0
  72. tpu_inference/mock/vllm_config_utils.py +28 -0
  73. tpu_inference/mock/vllm_envs.py +1233 -0
  74. tpu_inference/mock/vllm_logger.py +212 -0
  75. tpu_inference/mock/vllm_logging_utils.py +15 -0
  76. tpu_inference/models/__init__.py +0 -0
  77. tpu_inference/models/jax/__init__.py +0 -0
  78. tpu_inference/models/jax/deepseek_v3.py +868 -0
  79. tpu_inference/models/jax/llama3.py +366 -0
  80. tpu_inference/models/jax/llama4.py +473 -0
  81. tpu_inference/models/jax/llama_eagle3.py +333 -0
  82. tpu_inference/models/jax/phi3.py +376 -0
  83. tpu_inference/models/jax/qwen2.py +375 -0
  84. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  85. tpu_inference/models/jax/qwen3.py +302 -0
  86. tpu_inference/models/jax/utils/__init__.py +0 -0
  87. tpu_inference/models/jax/utils/file_utils.py +96 -0
  88. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  89. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  90. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  91. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  92. tpu_inference/models/vllm/__init__.py +0 -0
  93. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  94. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  95. tpu_inference/platforms/__init__.py +2 -0
  96. tpu_inference/platforms/tpu_jax.py +257 -0
  97. tpu_inference/runner/__init__.py +0 -0
  98. tpu_inference/runner/block_table_jax.py +122 -0
  99. tpu_inference/runner/compilation_manager.py +672 -0
  100. tpu_inference/runner/input_batch_jax.py +435 -0
  101. tpu_inference/runner/kv_cache.py +119 -0
  102. tpu_inference/runner/kv_cache_manager.py +460 -0
  103. tpu_inference/runner/lora_utils.py +92 -0
  104. tpu_inference/runner/multimodal_manager.py +208 -0
  105. tpu_inference/runner/persistent_batch_manager.py +244 -0
  106. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  107. tpu_inference/runner/structured_decoding_manager.py +89 -0
  108. tpu_inference/runner/tpu_jax_runner.py +771 -0
  109. tpu_inference/runner/utils.py +426 -0
  110. tpu_inference/spec_decode/__init__.py +0 -0
  111. tpu_inference/spec_decode/jax/__init__.py +0 -0
  112. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  113. tpu_inference/tpu_info.py +77 -0
  114. tpu_inference/utils.py +294 -0
  115. tpu_inference/worker/__init__.py +0 -0
  116. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  117. tpu_inference/worker/base.py +100 -0
  118. tpu_inference/worker/tpu_worker_jax.py +321 -0
  119. tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
  120. tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
  121. tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
  122. tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
  123. tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
tests/test_base.py ADDED
@@ -0,0 +1,201 @@
1
+ import logging
2
+ import unittest
3
+ import warnings
4
+ from dataclasses import dataclass, field, fields
5
+ from typing import Any, List, Mapping
6
+
7
# NOTE(review): no `tpu_inference/layers/...` module appears in this wheel's
# file list, so this import may raise ModuleNotFoundError when the tests run
# against the installed package -- confirm that
# `tpu_inference.layers.jax.base` actually ships.
from tpu_inference.layers.jax.base import Config

# Importing `vllm.config` can emit Python warnings and log records. Silence
# both for the duration of that import only, then put the "vllm" logger back
# to whatever level it had beforehand.
vllm_logger = logging.getLogger("vllm")
original_level = vllm_logger.level

with warnings.catch_warnings():
    # Ignore every warning raised inside this `with` block; catch_warnings
    # restores the previous warning filters when the block exits.
    warnings.simplefilter("ignore")

    # Raise the "vllm" logger threshold to ERROR so lower-severity records
    # logged through it during the import are dropped (this may be
    # overridden if vllm reconfigures logging while importing -- unverified).
    vllm_logger.setLevel(logging.ERROR)
    try:
        # Import the class; warnings raised here are suppressed.
        from vllm.config import ModelConfig
    finally:
        # Restore the level captured above rather than assuming a fixed
        # level, and do so even if the import itself fails.
        vllm_logger.setLevel(original_level)
23
+
24
+
25
def setup_vllm_config(subconfig_types: List[str],
                      overrides: List[Mapping[str, Any]]):
    """Return a fresh ``SimpleVllmConfig`` with the given overrides applied.

    ``subconfig_types`` and ``overrides`` are walked pairwise; as with
    ``zip``, trailing items in the longer list are ignored.

    Args:
        subconfig_types: Target of each override. ``"model"`` selects
            ``vllm_config.model_config``; any other value selects the
            top-level config object itself.
        overrides: One attribute-name -> value mapping per subconfig type.
            Each pair is applied with ``setattr`` on the selected target.

    Returns:
        The newly created and populated ``SimpleVllmConfig``.
    """
    config = SimpleVllmConfig()
    for target_kind, attrs in zip(subconfig_types, overrides):
        # Resolve the object to patch once per override entry.
        target = config.model_config if target_kind == "model" else config
        for attr_name, attr_value in attrs.items():
            setattr(target, attr_name, attr_value)
    return config
36
+
37
+
38
@dataclass
class SimpleVllmConfig:
    """Lightweight stand-in for the vLLM config object used by these tests.

    It exposes only two attributes: a top-level ``additional_config`` mapping
    and a real vLLM ``ModelConfig``. ``setup_vllm_config`` assigns attributes
    on one or the other to inject overrides.
    """

    # Extra settings; every instance starts with its own empty dict.
    additional_config: Mapping[str, Any] = field(default_factory=dict)
    # A ModelConfig built per instance. max_model_len is pinned to 1024;
    # per the original author's note this keeps construction warnings off.
    model_config: ModelConfig = field(
        default_factory=lambda: ModelConfig(max_model_len=1024))
44
+
45
+
46
@dataclass
class SimpleConfig(Config):
    """Toy ``Config`` subclass with three plain fields the tests override.

    ``vllm_config`` carries the override sources. Presumably the ``Config``
    base class applies them to ``arg1``..``arg3``; that logic is not visible
    here -- confirm in ``tpu_inference.layers.jax.base``.
    """

    vllm_config: SimpleVllmConfig
    arg1: str
    arg2: str
    arg3: int

    def is_equal(self, other: Config):
        """Return True if every field except ``vllm_config`` matches ``other``.

        Field names come from ``self``. Each value is compared with ``!=``
        against the same-named attribute on ``other``, stopping at the first
        mismatch.
        """
        return not any(
            getattr(self, f.name) != getattr(other, f.name)
            for f in fields(self)
            if f.name != "vllm_config")
59
+
60
+
61
class ConfigOverrideTests(unittest.TestCase):
    """Check how overrides carried by a vLLM config reach ``SimpleConfig``.

    Each test builds a ``SimpleConfig`` with the baseline values
    ``arg1="foo", arg2="bar", arg3=123`` on top of a vLLM config carrying
    some mix of ``additional_config``, ``hf_overrides`` and
    ``override_generation_config``. It then compares every field except
    ``vllm_config`` with the expected values.

    The expected values assume this precedence, highest first:
    ``override_generation_config`` > ``hf_overrides`` > ``additional_config``
    > constructor arguments.
    """

    @staticmethod
    def _non_vllm_fields(config):
        """Return ``{field_name: value}`` for all fields but ``vllm_config``."""
        return {
            f.name: getattr(config, f.name)
            for f in fields(config)
            if f.name != "vllm_config"
        }

    def _assert_overrides(self, subconfig_types, overrides, *, arg1, arg2,
                          arg3):
        """Apply ``overrides`` and assert the resulting field values.

        Args:
            subconfig_types: Passed through to ``setup_vllm_config``.
            overrides: Passed through to ``setup_vllm_config``.
            arg1, arg2, arg3: Expected field values after overrides apply.
        """
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        # Build the expectation through SimpleConfig with a default
        # (override-free) vLLM config, as the comparison target.
        expected_config = SimpleConfig(vllm_config=SimpleVllmConfig(),
                                       arg1=arg1,
                                       arg2=arg2,
                                       arg3=arg3)
        # Comparing plain dicts makes a failure report exactly which field
        # differs, instead of the bare "False is not true" that
        # assertTrue(config.is_equal(...)) would give.
        self.assertEqual(self._non_vllm_fields(config),
                         self._non_vllm_fields(expected_config))

    def test_additional_config_overrides(self):
        self._assert_overrides(
            [''],
            [{"additional_config": {"arg1": "val1", "arg2": "val2"}}],
            arg1="val1",
            arg2="val2",
            arg3=123)

    def test_hf_overrides(self):
        self._assert_overrides(
            ['model'],
            [{"hf_overrides": {"arg2": "val2", "arg3": 456}}],
            arg1="foo",
            arg2="val2",
            arg3=456)

    def test_additional_and_hf_overrides(self):
        self._assert_overrides(
            ['', 'model'],
            [{
                "additional_config": {
                    "arg1": "val1",
                    "arg2": "val2"
                }
            }, {
                "hf_overrides": {
                    "arg2": "val3",
                    "arg3": 456
                }
            }],
            arg1="val1",
            arg2="val3",
            arg3=456)

    def test_additional_and_generate_overrides(self):
        self._assert_overrides(
            ['', 'model'],
            [{
                "additional_config": {
                    "arg1": "val1",
                    "arg2": "val2"
                }
            }, {
                "override_generation_config": {
                    "arg2": "val3",
                    "arg3": 456
                }
            }],
            arg1="val1",
            arg2="val3",
            arg3=456)

    def test_hf_and_generate_overrides(self):
        self._assert_overrides(
            ['model', 'model'],
            [{
                "hf_overrides": {
                    "arg2": "val2",
                    "arg3": 456
                }
            }, {
                "override_generation_config": {
                    "arg2": "val4",
                    "arg3": 789
                }
            }],
            arg1="foo",
            arg2="val4",
            arg3=789)

    def test_additional_and_hf_and_generate_overrides(self):
        self._assert_overrides(
            ['', 'model', 'model'],
            [{
                "additional_config": {
                    "arg1": "val1",
                    "arg2": "val2"
                }
            }, {
                "hf_overrides": {
                    "arg2": "val2",
                    "arg3": 456
                }
            }, {
                "override_generation_config": {
                    "arg1": "val3",
                    "arg2": "val4",
                    "arg3": 789
                }
            }],
            arg1="val3",
            arg2="val4",
            arg3=789)
198
+
199
+
200
# Allow running this module directly (python test_base.py) in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()