hud-python 0.4.27-py3-none-any.whl → 0.4.29-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hud-python might be problematic.
- hud/__init__.py +2 -1
- hud/agents/base.py +73 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +65 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +563 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +348 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/tests/test_native_init.py +1 -1
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +31 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +586 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/computer/hud.py +4 -4
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/METADATA +20 -2
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/RECORD +67 -47
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/WHEEL +0 -0
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.27.dist-info → hud_python-0.4.29.dist-info}/licenses/LICENSE +0 -0
hud/rl/chat_template.jinja
ADDED

@@ -0,0 +1,101 @@
```jinja
{% set image_count = namespace(value=0) %}
{% set video_count = namespace(value=0) %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{{ messages[0]['content'] }}
{%- else -%}
{%- for content in messages[0]['content'] -%}
{%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
{%- set image_count.value = image_count.value + 1 -%}
{%- if add_vision_id -%}
{{ 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif -%}
{{ '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif content['type'] == 'video' or 'video' in content -%}
{%- set video_count.value = video_count.value + 1 -%}
{%- if add_vision_id -%}
{{ 'Video ' ~ video_count.value ~ ': ' }}
{%- endif -%}
{{ '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in content -%}
{{ content['text'] }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{{ 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
{%- endif -%}
{%- if tools -%}
{{ '\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n' }}
{{- tools | map('tojson') | join('\n') -}}
{{ '\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>' }}
{%- endif -%}
{{ '<|im_end|>\n' }}
{%- for message in messages -%}
{# Skip the first system message as it was already rendered. #}
{%- if loop.first and message.role == 'system' %}{% continue %}{% endif -%}

{# Render tool messages. The logic is slightly different with other messages. #}
{%- if message['role'] == 'tool' -%}
{%- if loop.first or messages[loop.index0 - 1]['role'] != 'tool' -%}
{{ '<|im_start|>user' }}
{%- endif -%}
{{ '\n<tool_response>\n' }}
{%- else -%}
{{ '<|im_start|>' ~ message['role'] ~ '\n' }}
{%- endif -%}

{%- if message['content'] is string -%}
{{ message['content'] }}
{%- else -%}
{%- for content in message['content'] -%}
{%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
{%- set image_count.value = image_count.value + 1 -%}
{%- if add_vision_id -%}
{{ 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif -%}
{{ '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif content['type'] == 'video' or 'video' in content -%}
{%- set video_count.value = video_count.value + 1 -%}
{%- if add_vision_id -%}
{{ 'Video ' ~ video_count.value ~ ': ' }}
{%- endif -%}
{{ '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in content and message['role'] == 'assistant' -%}
{% generation %} {{ content['text'] }} {% endgeneration %}
{%- elif 'text' in content -%}
{{ content['text'] }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{# Render tool_calls in AI messages. #}
{%- if message['role'] == 'assistant' and 'tool_calls' in message -%}
{# It will be cleaner if I can use some map function and join them with '\n' #}
{%- for tool_call in message['tool_calls'] -%}
{%- if tool_call['function'] is defined -%}
{%- set tool_call = tool_call['function'] -%}
{%- endif -%}
{# Handle the case where arguments is already a JSON string (OpenAI format) #}
{%- if tool_call.arguments is string -%}
{% generation %} {{ '<tool_call>\n{"name": "' }}{{ tool_call.name }}{{ '", "arguments": ' }}{{ tool_call.arguments }}{{ '}\n</tool_call>' }} {% endgeneration %}
{%- else -%}
{% generation %} {{ '<tool_call>\n' }}{{ tool_call | tojson }}{{ '\n</tool_call>' }} {% endgeneration %}
{%- endif -%}
{%- if not loop.last -%}
{% generation %} {{ '\n' }} {% endgeneration %}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if message['role'] == 'tool' -%}
{{ '\n</tool_response>' }}
{%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' -%}
{{ '<|im_end|>\n' }}
{%- endif -%}
{%- else -%}
{{ '<|im_end|>\n' }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ '<|im_start|>assistant\n' }}
{%- endif -%}
```
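The `{% generation %}…{% endgeneration %}` markers are what make this template RL-friendly: Hugging Face's `apply_chat_template` can turn those spans into an assistant-token mask, so a trainer can confine the loss to tokens the model actually generated. A minimal sketch of that usage (not part of this diff; the messages are made up, and a recent `transformers` release is assumed):

```python
from pathlib import Path

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
chat_template = Path("hud/rl/chat_template.jinja").read_text()

messages = [
    {"role": "system", "content": "You are an expert agent."},
    {"role": "user", "content": [{"type": "text", "text": "Click the submit button."}]},
    {"role": "assistant", "content": [{"type": "text", "text": "Clicking it now."}]},
]

encoded = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_template,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,  # populated from the {% generation %} spans
)
# encoded["assistant_masks"][i] == 1 exactly for tokens inside generation spans,
# which is the mask a GRPO/PPO-style loss would use.
```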
hud/rl/config.py
ADDED
@@ -0,0 +1,184 @@
```python
"""Configuration for RL training."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

# List of supported VL (Vision-Language) models
SUPPORTED_MODELS = [
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-14B-Instruct",
    "Qwen/Qwen2.5-VL-32B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
]


def validate_vl_model(model_name: str) -> None:
    """Validate that the model is a supported VL model.

    Args:
        model_name: The model name to validate

    Raises:
        ValueError: If the model is not a supported VL model
    """
    if not any(model_name.startswith(supported) for supported in SUPPORTED_MODELS):
        raise ValueError(
            f"Model '{model_name}' is not a supported VL model. "
            f"Only VL (Vision-Language) models are supported for RL training.\n"
            f"Supported models: {', '.join(SUPPORTED_MODELS)}\n"
            f"Note: '{model_name}' appears to be a text-only model."
        )


@dataclass
class ModelConfig:
    """Model and LoRA configuration."""

    base_model: str = "Qwen/Qwen2.5-VL-3B-Instruct"
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: tuple[str, ...] = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    min_pixels: int = 256 * 28 * 28
    max_pixels: int = 512 * 28 * 28
    attn_implementation: str = "flash_attention_2"
    use_liger: bool = True
    gradient_checkpointing: bool = True


@dataclass
class TrainingConfig:
    """Training hyperparameters."""

    # Training parameters
    training_steps: int = 100
    shuffle_dataset: bool = False
    save_every_batches: int = 1

    # Batching parameters
    epochs: int = 2
    batch_size: int = 24
    group_size: int = 4
    mini_batch_size: int = 1
    update_after_group: bool = True  # Whether to update the policy after each task group
    accumulate_over_minibatches: bool = False  # Whether to accumulate over minibatches

    # Advantage calculation parameters
    batch_level: Literal["group", "batch"] = "group"
    no_std: bool = False
    leave_one_out: bool = True

    # Replay buffer parameters
    buffer_steps: int = 4
    select_strategy: Literal["recent", "variance", "random"] = "variance"

    # Aggregation parameters
    ppo_mode: Literal["per_token", "per_trace"] = "per_token"
    token_agg: Literal["mean", "sum"] = "mean"  # noqa: S105

    # Regularization parameters
    kl_beta: float = 0.0
    entropy_beta: float = 0.0
    top_eps: float = 0.2
    bottom_eps: float = 0.1

    # Training hyperparameters
    lr: float = 3e-5
    grad_clip: float = 1.0

    # Adam hyperparameters
    use_8bit_optimizer: bool = True
    adam_betas: tuple[float, float] = (0.9, 0.999)
    adam_eps: float = 1e-8


@dataclass
class ActorConfig:
    """Actor/episode collection configuration."""

    # Execution parameters
    max_steps_per_episode: int = 5
    max_parallel_episodes: int = 48
    max_new_tokens: int = 1024
    force_tool_choice: bool = True
    allowed_tools: list[str] | None = None

    # Model parameters
    temperature: float = 0.7

    # Hud agent parameters
    system_prompt: str = "You are an expert agent. Complete the task efficiently."
    vllm_base_url: str = "http://localhost:8000/v1"
    vllm_api_key: str = "token-abc123"

    # Episode execution timeout (seconds)
    episode_timeout_sec: int = 600


@dataclass
class Config:
    """Main configuration combining all sub-configs."""

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    actor: ActorConfig = field(default_factory=ActorConfig)

    # Telemetry configuration
    job_name: str = "RL Training"
    job_id: str | None = None  # Use existing job ID if provided
    stats_interval: int = 1
    verbose: bool = False

    # Paths
    out_dir: str = "./checkpoints"
    adapter_prefix: str = "cua-grpo-step"

    # Misc
    seed: int = 1234

    @classmethod
    def from_dict(cls, d: dict) -> Config:
        """Create config from dictionary."""
        model = ModelConfig(**d.get("model", {}))
        training = TrainingConfig(**d.get("training", {}))
        actor = ActorConfig(**d.get("actor", {}))

        return cls(
            model=model,
            training=training,
            actor=actor,
            job_name=d.get("job_name", "RL Training"),
            job_id=d.get("job_id"),
            stats_interval=d.get("stats_interval", 1),
            verbose=d.get("verbose", False),
            out_dir=d.get("out_dir", "./checkpoints"),
            adapter_prefix=d.get("adapter_prefix", "cua-grpo-step"),
            seed=d.get("seed", 1234),
        )

    def to_dict(self) -> dict:
        """Convert config to dictionary."""
        return {
            "model": self.model.__dict__,
            "training": self.training.__dict__,
            "actor": self.actor.__dict__,
            "job_name": self.job_name,
            "job_id": self.job_id,
            "stats_interval": self.stats_interval,
            "verbose": self.verbose,
            "out_dir": self.out_dir,
            "adapter_prefix": self.adapter_prefix,
            "seed": self.seed,
        }
```
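Since `Config.from_dict` only reads the keys it knows about, a partial dict is enough to override defaults. A quick usage sketch (not part of the diff; the override values are illustrative, not recommendations):

```python
from hud.rl.config import Config, validate_vl_model

cfg = Config.from_dict(
    {
        "model": {"base_model": "Qwen/Qwen2.5-VL-7B-Instruct", "lora_r": 16},
        "training": {"batch_size": 16, "lr": 1e-5},
        "actor": {"max_steps_per_episode": 8},
        "out_dir": "./checkpoints",
    }
)
validate_vl_model(cfg.model.base_model)  # raises ValueError for non-VL models

# to_dict() round-trips cleanly, e.g. for logging or checkpoint metadata
assert Config.from_dict(cfg.to_dict()).model.lora_r == 16
```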
hud/rl/distributed.py
ADDED
@@ -0,0 +1,95 @@
```python
"""Distributed training utilities for GRPO."""

from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist


def setup_distributed() -> None:
    """Initialize distributed training environment."""
    if "RANK" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        # Set device for this process
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # Initialize process group
        dist.init_process_group("nccl")


def get_local_rank() -> int:
    """Get local rank from environment."""
    return int(os.environ.get("LOCAL_RANK", 0))


def get_global_rank() -> int:
    """Get global rank from environment."""
    return int(os.environ.get("RANK", 0))


def get_world_size() -> int:
    """Get world size from environment."""
    return int(os.environ.get("WORLD_SIZE", 1))


def cleanup_distributed() -> None:
    """Clean up distributed environment."""
    if dist.is_initialized():
        dist.destroy_process_group()


def is_main_process() -> bool:
    """Check if this is the main process (rank 0)."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def synchronize() -> None:
    """Synchronize all processes."""
    if dist.is_initialized():
        dist.barrier()


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all processes."""
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return tensor


def broadcast_object(obj: Any, src: int = 0) -> Any:
    """Broadcast a Python object from src rank to all ranks."""
    if not dist.is_initialized():
        return obj

    obj_list = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]


def gather_tensors(tensor: torch.Tensor) -> list[torch.Tensor] | None:
    """Gather tensors from all ranks to rank 0.

    Returns:
        List of tensors on rank 0, None on other ranks
    """
    if not dist.is_initialized():
        return [tensor]

    world_size = dist.get_world_size()

    if dist.get_rank() == 0:
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.gather(tensor, gathered, dst=0)
        return gathered
    else:
        dist.gather(tensor, None, dst=0)
        return None
```
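Every helper here is guarded by `dist.is_initialized()`, so the same script runs unmodified as a single process or under a launcher. A sketch of the intended pattern (not part of the diff; the script name and loss value are placeholders, and a CUDA device is assumed since the backend is NCCL):

```python
# launch: torchrun --nproc-per-node 2 train_step.py
import torch

from hud.rl.distributed import (
    all_reduce_mean,
    cleanup_distributed,
    get_local_rank,
    is_main_process,
    setup_distributed,
    synchronize,
)

setup_distributed()  # no-op unless RANK/WORLD_SIZE are set by the launcher
loss = torch.tensor(0.5, device=f"cuda:{get_local_rank()}")
loss = all_reduce_mean(loss)  # average across ranks, e.g. for logging
if is_main_process():
    print(f"mean loss: {loss.item():.4f}")
synchronize()  # keep ranks aligned before teardown
cleanup_distributed()
```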