tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py
CHANGED

@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA)
 # yapf: enable
-from vllm.lora.
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
-            embeddings_tensor=lora.embeddings_tensor,
         )

         lora_dict[lora_id] = lora
@@ -210,6 +205,9 @@ def create_random_inputs(
 @pytest.mark.parametrize("repeats", [1, 2, 3])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+    # TODO(Qiliang Cui): Remove when issue is resolved.
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
     set_random_seed(6)

     max_loras = 9
@@ -546,7 +544,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/test_lora_perf.py
ADDED

@@ -0,0 +1,53 @@
+import os
+import time
+
+import pytest
+import vllm
+from vllm.lora.request import LoRARequest
+
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_lora_performance(tp):
+    prompt = "What is 1+1? \n"
+    llm_without_lora = vllm.LLM(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        max_model_len=256,
+        max_num_batched_tokens=64,
+        max_num_seqs=8,
+        tensor_parallel_size=tp,
+    )
+    start_time = time.time()
+    llm_without_lora.generate(
+        prompt,
+        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
+    )[0].outputs[0].text
+    base_time = time.time() - start_time
+
+    del llm_without_lora
+    # Waiting for TPUs to be released
+    time.sleep(10)
+
+    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                             max_model_len=256,
+                             max_num_batched_tokens=64,
+                             max_num_seqs=8,
+                             tensor_parallel_size=tp,
+                             enable_lora=True,
+                             max_loras=1,
+                             max_lora_rank=8)
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    start_time = time.time()
+    llm_with_lora.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+    lora_time = time.time() - start_time
+    print(f"Base time: {base_time}, LoRA time: {lora_time}")
+    assert (base_time /
+            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
+
+    del llm_with_lora
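Note: a minimal sketch of driving this new perf test on its own (the pytest invocation below is illustrative, not part of the package; the run assumes TPU hardware and access to the Qwen checkpoint and LoRA adapter referenced above):

    import pytest

    # Runs only the new perf test; "-s" surfaces the base/LoRA timing printout.
    pytest.main(["tests/lora/test_lora_perf.py", "-k", "test_lora_performance", "-s"])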
tests/lora/utils.py
CHANGED

@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                             dtype=weight.dtype,
                             device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)

         return lora
tests/test_envs.py
CHANGED

@@ -56,6 +56,13 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):


 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for boolean vars by setting to default "0"
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
+
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -63,6 +70,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False

+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -73,22 +87,110 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
     assert envs.USE_MOE_EP_KERNEL is True

+    # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
+
+
+def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars accept string values like 'True' and 'False'"""
+
+    # Test NEW_MODEL_DESIGN with string "True"
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    # Test SKIP_JAX_PRECOMPILE with string values
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
+    assert envs.SKIP_JAX_PRECOMPILE is True
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
+    assert envs.SKIP_JAX_PRECOMPILE is False
+
+    # Test VLLM_XLA_CHECK_RECOMPILATION with string values
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
+    # Test USE_MOE_EP_KERNEL with string values
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
+    assert envs.USE_MOE_EP_KERNEL is True
+
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
+    assert envs.USE_MOE_EP_KERNEL is False
+
+
+def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars raise errors for invalid values"""
+
+    # Test invalid value for NEW_MODEL_DESIGN
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
+    with pytest.raises(ValueError,
+                       match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    # Test invalid value for SKIP_JAX_PRECOMPILE
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
+        _ = envs.SKIP_JAX_PRECOMPILE
+
+
+def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
+    """Test that empty string returns default value"""
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "")
+    assert envs.NEW_MODEL_DESIGN is False  # Should return default
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
+    assert envs.SKIP_JAX_PRECOMPILE is False  # Should return default
+

 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for integer vars by setting to defaults
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
+    monkeypatch.setenv("NUM_SLICES", "1")
+
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0

+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4

-def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"

-
+def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
+    # Test case sensitive choices
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"

+    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
+    assert envs.MODEL_IMPL_TYPE == "vllm"
+

 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -117,8 +219,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"

     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
-    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
-    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"


 def test_invalid_attribute_raises_error():
@@ -134,6 +234,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars


@@ -141,11 +242,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"

-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
-    assert envs.TPU_MULTIHOST_BACKEND == "
-
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
-    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
+    assert envs.TPU_MULTIHOST_BACKEND == "ray"


 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
@@ -158,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):

 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "
+    assert envs.MODEL_IMPL_TYPE == "auto"


 def test_cache_preserves_values_across_env_changes(
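Note: taken together, these tests pin down one parsing contract for boolean variables. A standalone sketch of that contract, mirroring the env_bool helper added in tpu_inference/envs.py further down (not imported from the package):

    def parse_bool(value: str | None, default: bool = False) -> bool:
        # Unset or empty string falls back to the default
        # (test_boolean_env_vars_empty_string).
        if value is None or value == "":
            return default
        lowered = value.lower()
        if lowered in ("true", "1"):
            return True
        if lowered in ("false", "0"):
            return False
        # Anything else ("yes", "2", "invalid") is rejected
        # (test_boolean_env_vars_invalid_values).
        raise ValueError(f"Invalid boolean value {value!r}")

    assert parse_bool("TRUE") is True
    assert parse_bool("") is False  # default wins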
tests/test_quantization.py
CHANGED

@@ -112,6 +112,8 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.mesh = Mesh(jax.devices(), ('model', ))
         self.rng = jax.random.PRNGKey(0)
         self.model = SimpleModel(rngs=nnx.Rngs(0))
+        self.model.vllm_config = MagicMock()
+        self.model.vllm_config.model_config.use_mla = False

         self.qwix_config = [
             {
@@ -131,6 +133,7 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         """Test that qwix.quantize_model is called with the correct arguments."""
         quantized_model_mock = MagicMock(spec=nnx.Module)
         mock_quantize_model.return_value = quantized_model_mock
+        self.model.vllm_config.sharding_config.total_dp_size = 1

         with patch(
                 "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
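Note: the two new vllm_config lines work because MagicMock auto-creates nested attributes on access, so arbitrary config paths can be stubbed without constructing a real config object. A minimal illustration:

    from unittest.mock import MagicMock

    cfg = MagicMock()
    cfg.model_config.use_mla = False       # intermediate mocks created on the fly
    cfg.sharding_config.total_dp_size = 1
    assert cfg.model_config.use_mla is False
    assert cfg.sharding_config.total_dp_size == 1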
tests/test_utils.py
CHANGED

@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
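Note: the updated assertions imply that "fp8_e4m3" now resolves to a concrete JAX dtype and that "auto" is no longer covered here. A sketch of the mapping the test exercises, reconstructed from the assertions alone (assumed shape, not the package's actual implementation):

    import jax.numpy as jnp

    # Assumed mapping, inferred from the test's assertions above.
    _STR_TO_JAX_DTYPE = {
        "int8": jnp.int8,
        "bfloat16": jnp.bfloat16,
        "fp8": jnp.float8_e4m3fn,
        "fp8_e4m3": jnp.float8_e4m3fn,
        "fp8_e5m2": jnp.float8_e5m2,
    }

    def get_jax_dtype_from_str_dtype(str_dtype: str):
        return _STR_TO_JAX_DTYPE[str_dtype.strip().lower()]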
tpu_inference/__init__.py
CHANGED

@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)

-if "proxy" in
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
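Note: the new block forces, at import time, the platform resolution that vLLM otherwise performs lazily. A rough sketch of the two paths, assuming vLLM's lazy current_platform proxy (the resolve_* names come from the diff above):

    # Normal path: resolution is deferred until the first attribute access.
    from vllm.platforms import current_platform
    _ = current_platform.device_type  # topology discovery happens here, lazily

    # Pathways path, mirroring the diff: resolve eagerly so multi-host
    # topology discovery completes before any other component loads.
    from vllm.platforms import (resolve_current_platform_cls_qualname,
                                resolve_obj_by_qualname)
    platform = resolve_obj_by_qualname(resolve_current_platform_cls_qualname())()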
tpu_inference/core/disagg_utils.py
CHANGED

@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import Tuple

-
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs


 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES
+    return bool(envs.PREFILL_SLICES)


 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:


 def get_prefill_slices() -> Tuple[int, ...]:
-    if
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.PREFILL_SLICES)


 def get_decode_slices() -> Tuple[int, ...]:
-    if
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.DECODE_SLICES)
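Note: behaviorally the rewrite keeps the old rule: disaggregation is enabled exactly when PREFILL_SLICES is a non-empty string, only now read through tpu_inference.envs instead of os.environ. A quick sketch (the slice spec value is made up):

    import os

    os.environ["PREFILL_SLICES"] = "2x4,2x4"  # hypothetical slice spec
    # envs.PREFILL_SLICES re-reads the variable on each access, so:
    #   bool("")         -> False  (disagg off)
    #   bool("2x4,2x4")  -> True   (disagg on)
    assert bool(os.environ.get("PREFILL_SLICES", "")) is True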
tpu_inference/distributed/tpu_connector.py
CHANGED

@@ -60,7 +60,6 @@ D workflow:

 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:

         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host =
-            "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
@@ -458,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()

         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -500,6 +498,7 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        self._maybe_start_p2p_server()

     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
tpu_inference/distributed/utils.py
CHANGED

@@ -2,6 +2,7 @@ import os

 from vllm.utils.network_utils import get_ip

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):


 def get_kv_ips() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:


 def get_kv_ports() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py
CHANGED

@@ -15,18 +15,93 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
-
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
+
+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
+def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
+    """
+    Accepts both numeric strings ("0", "1") and boolean strings
+    ("true", "false", "True", "False").
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default boolean value if not set
+    """
+
+    def _get_bool_env() -> bool:
+        value = os.getenv(env_name)
+        if value is None or value == "":
+            return default
+
+        value_lower = value.lower()
+        if value_lower in ("true", "1"):
+            return True
+        elif value_lower in ("false", "0"):
+            return False
+        else:
+            raise ValueError(
+                f"Invalid boolean value '{value}' for {env_name}. "
+                f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
+
+    return _get_bool_env
+

 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -38,7 +113,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +122,37 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-
+    env_bool("SKIP_JAX_PRECOMPILE", default=False),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-
+    env_bool("NEW_MODEL_DESIGN", default=False),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-
+    env_bool("USE_MOE_EP_KERNEL", default=False),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+    "ENABLE_QUANTIZED_MATMUL_KERNEL":
+    lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
 }
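Note: a quick standalone sketch of how the two new helpers behave, using only the definitions from this diff:

    import os

    os.environ["NEW_MODEL_DESIGN"] = "True"
    assert env_bool("NEW_MODEL_DESIGN", default=False)() is True

    os.environ["TPU_MULTIHOST_BACKEND"] = "grpc"  # not among ["ray"]
    try:
        env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"])()
    except ValueError as exc:
        # Invalid value 'grpc' for TPU_MULTIHOST_BACKEND. Valid options: ['ray'].
        print(exc)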
tpu_inference/executors/ray_distributed_executor.py
CHANGED

@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)

     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -133,10 +136,14 @@ class RayDistributedExecutor(RayDistributedExecutorV1):

         pp_size = self.parallel_config.pipeline_parallel_size
         placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
         if pp_size == 1:
             placement_group_specs = [{
                 device_str: node['Resources'][device_str]
-            } for node in
+            } for node in ray_nodes]
         else:
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
@@ -352,7 +359,7 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
         if self.parallel_config.pipeline_parallel_size > 1:
-            self.
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")

         if self.use_ray_spmd_worker: