tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/lora/test_lora.py ADDED
@@ -0,0 +1,133 @@
+ # https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
+ import os
+ import time
+
+ import pytest
+ import vllm
+ from vllm.lora.request import LoRARequest
+
+ # This file contains tests to ensure that LoRA works correctly on the TPU
+ # backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
+ # for this. The adapters are:
+ # Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
+ # from 1 to 4.
+
+ # These adapters are trained using a standard huggingface peft training script,
+ # where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
+ # 100 training iterations with a training batch size of 100.
+
+
+ def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
+     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                     max_model_len=256,
+                     max_num_batched_tokens=64,
+                     max_num_seqs=8,
+                     tensor_parallel_size=tp,
+                     enable_lora=True,
+                     max_loras=num_loras,
+                     max_lora_rank=8)
+
+
+ # For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
+ TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+ @pytest.mark.parametrize("tp", TP)
+ def test_single_lora(tp):
+     """
+     This test ensures we can run a single LoRA adapter on the TPU backend.
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=2.
+     """
+
+     llm = setup_vllm(1, tp)
+
+     prompt = "What is 1+1? \n"
+
+     lora_request = LoRARequest(
+         "lora_adapter_2", 2,
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+     output = llm.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+
+     answer = output.strip()[0]
+
+     assert answer.isdigit()
+     assert int(answer) == 2
+
+     del llm
+     time.sleep(10)
+
+
+ @pytest.mark.parametrize("tp", TP)
+ def test_lora_hotswapping(tp):
+     """
+     This test ensures we can run multiple LoRA adapters on the TPU backend, even
+     if we only have space to store 1.
+
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+     """
+
+     lora_name_template = \
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+     lora_requests = [
+         LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+         for i in range(1, 5)
+     ]
+
+     llm = setup_vllm(1, tp)
+
+     prompt = "What is 1+1? \n"
+
+     for i, req in enumerate(lora_requests):
+         output = llm.generate(prompt,
+                               sampling_params=vllm.SamplingParams(
+                                   max_tokens=16, temperature=0),
+                               lora_request=req)[0].outputs[0].text
+         answer = output.strip()[0]
+
+         assert answer.isdigit()
+         assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"
+
+     del llm
+     time.sleep(10)
+
+
+ @pytest.mark.parametrize("tp", TP)
+ def test_multi_lora(tp):
+     """
+     This test ensures we can run multiple LoRA adapters on the TPU backend, when
+     we have enough space to store all of them.
+
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+     """
+     lora_name_template = \
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+     lora_requests = [
+         LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+         for i in range(1, 5)
+     ]
+
+     llm = setup_vllm(4, tp)
+
+     prompt = "What is 1+1? \n"
+
+     for i, req in enumerate(lora_requests):
+         output = llm.generate(prompt,
+                               sampling_params=vllm.SamplingParams(
+                                   max_tokens=16, temperature=0),
+                               lora_request=req)[0].outputs[0].text
+
+         answer = output.strip()[0]
+
+         assert answer.isdigit()
+         assert int(
+             output.strip()
+             [0]) == i + 1, f"Expected {i + 1}, got {int(output.strip()[0])}"
+
+     del llm
+     time.sleep(10)
tests/lora/utils.py ADDED
@@ -0,0 +1,96 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import torch
+ from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
+
+
+ # https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
+ class DummyLoRAManager:
+
+     def __init__(self, device: torch.device = "cuda:0"):
+         super().__init__()
+         self._loras: dict[str, LoRALayerWeights] = {}
+         self._device = device
+
+     def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
+         self._loras[module_name] = lora
+
+     def get_module_lora(self, module_name: str) -> LoRALayerWeights:
+         return self._loras[module_name]
+
+     def init_random_lora(
+         self,
+         module_name: str,
+         weight: torch.Tensor,
+         rank: int = 8,
+         generate_embeddings_tensor: int = 0,
+     ):
+         lora = LoRALayerWeights(
+             module_name,
+             rank=rank,
+             lora_alpha=1,
+             lora_a=torch.rand([rank, weight.shape[1]],
+                               dtype=weight.dtype,
+                               device=self._device),
+             lora_b=torch.rand([weight.shape[0], rank],
+                               dtype=weight.dtype,
+                               device=self._device),
+         )
+         if generate_embeddings_tensor:
+             lora.embeddings_tensor = torch.rand(
+                 5,
+                 generate_embeddings_tensor,
+                 dtype=weight.dtype,
+                 device=self._device,
+             )
+         self.set_module_lora(module_name, lora)
+
+         return lora
+
+     def init_lora(
+         self,
+         module_name: str,
+         input_dim: int,
+         output_dim: int,
+         rank=8,
+         noop=False,
+         embeddings_tensor=None,
+     ):
+         lora = LoRALayerWeights(
+             module_name,
+             rank=rank,
+             lora_alpha=1,
+             lora_a=torch.rand([rank, input_dim], device="cuda"),
+             lora_b=torch.rand([output_dim, input_dim], device="cuda"),
+             embeddings_tensor=embeddings_tensor,
+         )
+         self.set_module_lora(module_name, lora)
+         return lora
+
+     def reset_lora(self):
+         self._loras = {}
+
+     def init_packed_lora(
+         self,
+         module_name: str,
+         input_dim: int,
+         output_dims: list[int],
+         noop_lora_index: list[int] | None = None,
+         rank: int = 8,
+     ):
+         base_loras: list[LoRALayerWeights] = []
+         noop_lora_index_set = set(noop_lora_index or [])
+
+         for i, out_dim in enumerate(output_dims):
+             base_lora = self.init_lora(
+                 module_name + "_000_" + str(i),
+                 input_dim,
+                 out_dim,
+                 rank=rank,
+                 noop=i in noop_lora_index_set,
+             )
+             base_loras.append(base_lora)
+         packed_lora = PackedLoRALayerWeights.pack(base_loras)
+         self.set_module_lora(module_name, packed_lora)
+         return packed_lora
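
For reference, a minimal usage sketch of the DummyLoRAManager defined above (this example is not part of the package; it assumes the tests directory is importable and uses a CPU device so it can run outside the TPU/GPU test environment):

    import torch
    from tests.lora.utils import DummyLoRAManager

    # Build a dummy LoRA for a hypothetical 16x32 linear layer on CPU.
    manager = DummyLoRAManager(device="cpu")
    layer_weight = torch.zeros(16, 32)
    lora = manager.init_random_lora("layer0", layer_weight, rank=4)

    # Shapes follow directly from init_random_lora above.
    assert lora.lora_a.shape == (4, 32)   # [rank, input_dim]
    assert lora.lora_b.shape == (16, 4)   # [output_dim, rank]
    assert manager.get_module_lora("layer0") is lora

The layer tests (tests/lora/test_layers.py) use this manager to fabricate adapter weights without loading a real checkpoint.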
tests/test_base.py ADDED
@@ -0,0 +1,201 @@
+ import logging
+ import unittest
+ import warnings
+ from dataclasses import dataclass, field, fields
+ from typing import Any, List, Mapping
+
+ from tpu_inference.layers.jax.base import Config
+
+ # Use the 'warnings' module to globally ignore warnings within this block
+ vllm_logger = logging.getLogger("vllm")
+ original_level = vllm_logger.level
+
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore")
+
+     # Set the vLLM logger to ERROR to suppress its messages
+     vllm_logger.setLevel(logging.ERROR)
+
+     # Import the class; all warnings will be suppressed
+     from vllm.config import ModelConfig
+
+ vllm_logger.setLevel(logging.WARNING)
+
+
+ def setup_vllm_config(subconfig_types: List[str],
+                       overrides: List[Mapping[str, Any]]):
+     vllm_config = SimpleVllmConfig()
+     for (subconfig_type, override) in zip(subconfig_types, overrides):
+         if subconfig_type == "model":
+             for key in override:
+                 setattr(vllm_config.model_config, key, override[key])
+         else:
+             for key in override:
+                 setattr(vllm_config, key, override[key])
+     return vllm_config
+
+
+ @dataclass
+ class SimpleVllmConfig():
+     additional_config: Mapping[str, Any] = field(default_factory=dict)
+     # Set default max_model_len to turn off warnings.
+     model_config: ModelConfig = field(
+         default_factory=lambda: ModelConfig(max_model_len=1024))
+
+
+ @dataclass
+ class SimpleConfig(Config):
+     vllm_config: SimpleVllmConfig
+     arg1: str
+     arg2: str
+     arg3: int
+
+     def is_equal(self, other: Config):
+         for f in fields(self):
+             if f.name != "vllm_config":
+                 if getattr(self, f.name) != getattr(other, f.name):
+                     return False
+         return True
+
+
+ class ConfigOverrideTests(unittest.TestCase):
+
+     def test_additional_config_overrides(self):
+         subconfig_types = ['']
+         overrides = [{"additional_config": {"arg1": "val1", "arg2": "val2"}}]
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         default_vllm_config = SimpleVllmConfig()
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="val1",
+                                        arg2="val2",
+                                        arg3=123)
+         self.assertTrue(config.is_equal(expected_config))
+
+     def test_hf_overrides(self):
+         subconfig_types = ['model']
+         overrides = [{"hf_overrides": {"arg2": "val2", "arg3": 456}}]
+         default_vllm_config = SimpleVllmConfig()
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="foo",
+                                        arg2="val2",
+                                        arg3=456)
+         self.assertTrue(config.is_equal(expected_config))
+
+     def test_additional_and_hf_overrides(self):
+         subconfig_types = ['', 'model']
+         overrides = [{
+             "additional_config": {
+                 "arg1": "val1",
+                 "arg2": "val2"
+             }
+         }, {
+             "hf_overrides": {
+                 "arg2": "val3",
+                 "arg3": 456
+             }
+         }]
+         default_vllm_config = SimpleVllmConfig()
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="val1",
+                                        arg2="val3",
+                                        arg3=456)
+         self.assertTrue(config.is_equal(expected_config))
+
+     def test_additional_and_generate_overrides(self):
+         subconfig_types = ['', 'model']
+         overrides = [{
+             "additional_config": {
+                 "arg1": "val1",
+                 "arg2": "val2"
+             }
+         }, {
+             "override_generation_config": {
+                 "arg2": "val3",
+                 "arg3": 456
+             }
+         }]
+         default_vllm_config = SimpleVllmConfig()
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="val1",
+                                        arg2="val3",
+                                        arg3=456)
+         self.assertTrue(config.is_equal(expected_config))
+
+     def test_hf_and_generate_overrides(self):
+         subconfig_types = ['model', 'model']
+         overrides = [{
+             "hf_overrides": {
+                 "arg2": "val2",
+                 "arg3": 456
+             }
+         }, {
+             "override_generation_config": {
+                 "arg2": "val4",
+                 "arg3": 789
+             }
+         }]
+         default_vllm_config = SimpleVllmConfig()
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="foo",
+                                        arg2="val4",
+                                        arg3=789)
+         self.assertTrue(config.is_equal(expected_config))
+
+     def test_additional_and_hf_and_generate_overrides(self):
+         subconfig_types = ['', 'model', 'model']
+         overrides = [{
+             "additional_config": {
+                 "arg1": "val1",
+                 "arg2": "val2"
+             }
+         }, {
+             "hf_overrides": {
+                 "arg2": "val2",
+                 "arg3": 456
+             }
+         }, {
+             "override_generation_config": {
+                 "arg1": "val3",
+                 "arg2": "val4",
+                 "arg3": 789
+             }
+         }]
+         default_vllm_config = SimpleVllmConfig()
+         override_vllm_config = setup_vllm_config(subconfig_types, overrides)
+         config = SimpleConfig(vllm_config=override_vllm_config,
+                               arg1="foo",
+                               arg2="bar",
+                               arg3=123)
+         expected_config = SimpleConfig(vllm_config=default_vllm_config,
+                                        arg1="val3",
+                                        arg2="val4",
+                                        arg3=789)
+         self.assertTrue(config.is_equal(expected_config))
+
+
+ if __name__ == '__main__':
+     unittest.main()
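
The expected values in these tests imply a merge order for Config fields: additional_config is applied first, then hf_overrides, then override_generation_config, so the generation overrides win on conflicts. A standalone sketch of that implied precedence (illustration only, not the actual tpu_inference.layers.jax.base.Config implementation; the helper name merge_overrides is hypothetical):

    def merge_overrides(defaults: dict,
                        additional_config: dict,
                        hf_overrides: dict,
                        override_generation_config: dict) -> dict:
        # Later dictionaries win on key conflicts, mirroring the
        # precedence asserted by ConfigOverrideTests above.
        merged = dict(defaults)
        for layer in (additional_config, hf_overrides,
                      override_generation_config):
            merged.update(layer)
        return merged

    # Mirrors test_additional_and_hf_and_generate_overrides:
    print(merge_overrides({"arg1": "foo", "arg2": "bar", "arg3": 123},
                          {"arg1": "val1", "arg2": "val2"},
                          {"arg2": "val2", "arg3": 456},
                          {"arg1": "val3", "arg2": "val4", "arg3": 789}))
    # -> {'arg1': 'val3', 'arg2': 'val4', 'arg3': 789}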
tests/test_envs.py ADDED
@@ -0,0 +1,182 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+ import pytest
+
+ import tpu_inference.envs as envs
+ from tpu_inference.envs import enable_envs_cache, environment_variables
+
+
+ def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
+     assert envs.JAX_PLATFORMS == ""
+     assert envs.PHASED_PROFILING_DIR == ""
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+     monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
+     assert envs.JAX_PLATFORMS == "tpu"
+     assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
+
+     assert envs.TPU_NAME is None
+     assert envs.TPU_ACCELERATOR_TYPE is None
+     monkeypatch.setenv("TPU_NAME", "my-tpu")
+     monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
+     assert envs.TPU_NAME == "my-tpu"
+     assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
+
+     # __getattr__ is not decorated with functools.cache
+     assert not hasattr(envs.__getattr__, "cache_info")
+
+
+ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+     monkeypatch.setenv("TPU_NAME", "my-tpu")
+
+     # __getattr__ is not decorated with functools.cache
+     assert not hasattr(envs.__getattr__, "cache_info")
+
+     enable_envs_cache()
+
+     # __getattr__ is decorated with functools.cache
+     assert hasattr(envs.__getattr__, "cache_info")
+     start_hits = envs.__getattr__.cache_info().hits
+
+     # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
+     assert envs.JAX_PLATFORMS == "tpu"
+     assert envs.TPU_NAME == "my-tpu"
+     assert envs.__getattr__.cache_info().hits == start_hits + 2
+
+     # All environment variables are cached
+     for environment_variable in environment_variables:
+         envs.__getattr__(environment_variable)
+     assert envs.__getattr__.cache_info(
+     ).hits == start_hits + 2 + len(environment_variables)
+
+     # Reset envs.__getattr__ back to non-cached version to
+     # avoid affecting other tests
+     envs.__getattr__ = envs.__getattr__.__wrapped__
+
+
+ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+     # Test SKIP_JAX_PRECOMPILE (default False)
+     assert envs.SKIP_JAX_PRECOMPILE is False
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
+     assert envs.SKIP_JAX_PRECOMPILE is True
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+     assert envs.SKIP_JAX_PRECOMPILE is False
+
+     # Test NEW_MODEL_DESIGN (default False)
+     assert envs.NEW_MODEL_DESIGN is False
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
+     assert envs.NEW_MODEL_DESIGN is True
+
+     # Test USE_MOE_EP_KERNEL (default False)
+     assert envs.USE_MOE_EP_KERNEL is False
+     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
+     assert envs.USE_MOE_EP_KERNEL is True
+
+
+ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+     assert envs.PYTHON_TRACER_LEVEL == 1
+     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
+     assert envs.PYTHON_TRACER_LEVEL == 3
+     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
+     assert envs.PYTHON_TRACER_LEVEL == 0
+
+
+ def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
+     assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+     monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+ def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
+     monkeypatch.delenv("PREFILL_SLICES", raising=False)
+     monkeypatch.delenv("DECODE_SLICES", raising=False)
+
+     assert envs.JAX_PLATFORMS == ""
+     assert envs.PREFILL_SLICES == ""
+     assert envs.DECODE_SLICES == ""
+     assert envs.PHASED_PROFILING_DIR == ""
+
+
+ def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
+     monkeypatch.delenv("TPU_NAME", raising=False)
+     monkeypatch.delenv("TPU_WORKER_ID", raising=False)
+
+     assert envs.TPU_ACCELERATOR_TYPE is None
+     assert envs.TPU_NAME is None
+     assert envs.TPU_WORKER_ID is None
+
+
+ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
+     assert envs.RAY_USAGE_STATS_ENABLED == "0"
+     monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
+     assert envs.RAY_USAGE_STATS_ENABLED == "1"
+
+     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
+     monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
+     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
+
+
+ def test_invalid_attribute_raises_error():
+     with pytest.raises(AttributeError,
+                        match="has no attribute 'NONEXISTENT_VAR'"):
+         _ = envs.NONEXISTENT_VAR
+
+
+ def test_dir_returns_all_env_vars():
+     env_vars = envs.__dir__()
+     assert isinstance(env_vars, list)
+     assert len(env_vars) == len(environment_variables)
+     assert "JAX_PLATFORMS" in env_vars
+     assert "TPU_NAME" in env_vars
+     assert "SKIP_JAX_PRECOMPILE" in env_vars
+     assert "MODEL_IMPL_TYPE" in env_vars
+
+
+ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("TPU_WORKER_ID", "0")
+     assert envs.TPU_WORKER_ID == "0"
+
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
+     assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
+     assert envs.TPU_MULTIHOST_BACKEND == "xla"
+
+
+ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
+     assert envs.PREFILL_SLICES == "0,1,2,3"
+
+     monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
+     assert envs.DECODE_SLICES == "4,5,6,7"
+
+
+ def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
+     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+ def test_cache_preserves_values_across_env_changes(
+         monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+
+     enable_envs_cache()
+
+     assert envs.JAX_PLATFORMS == "tpu"
+
+     # Change environment variable
+     monkeypatch.setenv("JAX_PLATFORMS", "cpu")
+
+     # Cached value should still be "tpu"
+     assert envs.JAX_PLATFORMS == "tpu"
+
+     # Reset envs.__getattr__ back to non-cached version
+     envs.__getattr__ = envs.__getattr__.__wrapped__
+
+     # Now it should reflect the new value
+     assert envs.JAX_PLATFORMS == "cpu"
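
These tests rely on the module-level __getattr__ pattern (PEP 562): each variable name maps to a parser, and enable_envs_cache wraps __getattr__ with functools.cache so later reads ignore environment changes until the wrapper is unwound via __wrapped__. A simplified sketch of that pattern (illustrative only; the actual tpu_inference/envs.py may define its variables and parsers differently):

    # sketch of an envs-style module
    import functools
    import os

    environment_variables = {
        "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", ""),
        "TPU_NAME": lambda: os.getenv("TPU_NAME", None),
        "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
        "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
    }


    def __getattr__(name: str):
        # Called for attribute lookups the module does not define directly.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


    def __dir__():
        return list(environment_variables)


    def enable_envs_cache():
        # After this call each variable is read from os.environ at most once;
        # the uncached function stays reachable via __wrapped__.
        global __getattr__
        __getattr__ = functools.cache(__getattr__)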