megatron-fsdp 0.2.0.dev112119__tar.gz → 0.2.0.dev112613__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/PKG-INFO +1 -1
  2. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/megatron_fsdp.py +4 -3
  3. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/package_info.py +1 -1
  4. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/param_and_grad_buffer.py +3 -0
  5. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/uneven_dtensor.py +10 -1
  6. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/utils.py +3 -6
  7. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp.egg-info/PKG-INFO +1 -1
  8. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/README.md +0 -0
  9. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/__init__.py +0 -0
  10. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/distributed_data_parallel_config.py +0 -0
  11. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp/fully_shard.py +0 -0
  12. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp.egg-info/SOURCES.txt +0 -0
  13. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp.egg-info/dependency_links.txt +0 -0
  14. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp.egg-info/requires.txt +0 -0
  15. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/megatron_fsdp.egg-info/top_level.txt +0 -0
  16. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/pyproject.toml +0 -0
  17. {megatron_fsdp-0.2.0.dev112119 → megatron_fsdp-0.2.0.dev112613}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: megatron-fsdp
3
- Version: 0.2.0.dev112119
3
+ Version: 0.2.0.dev112613
4
4
  Summary: **Megatron-FSDP** is an NVIDIA-developed PyTorch extension that provides a high-performance implementation of Fully Sharded Data Parallelism (FSDP)
5
5
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
6
6
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -898,9 +898,10 @@ class MegatronFSDP(torch.nn.Module):
898
898
 
899
899
  # Register pre state_dict hook to ensure that the module parameters are
900
900
  # distributed before saving the state_dict.
901
- self._state_dict_pre_hook = self.module.register_state_dict_pre_hook(
902
- lambda *args, **kwargs: self._replace_param_with_distributed_if_needed()
903
- )
901
+ for name, module in self.named_modules():
902
+ module.register_state_dict_pre_hook(
903
+ lambda *args, **kwargs: self._replace_param_with_distributed_if_needed()
904
+ )
904
905
 
905
906
  @contextmanager
906
907
  def no_sync(self):
@@ -4,7 +4,7 @@
4
4
  MAJOR = 0
5
5
  MINOR = 2
6
6
  PATCH = 0
7
- PRE_RELEASE = '0.dev112119'
7
+ PRE_RELEASE = '0.dev112613'
8
8
 
9
9
  # Use the following formatting: (major, minor, patch, pre-release)
10
10
  VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
@@ -2782,6 +2782,9 @@ class GradReducePipeline:
2782
2782
  outer_fsdp_group_grad_reduce (bool, optional): Whether to reduce gradients
2783
2783
  across outer-DP groups. Defaults to False.
2784
2784
  """
2785
+ # Sort parameters by their bucket IDs to ensure a deterministic processing order.
2786
+ # Performing reduce-scatter operations out of order can lead to hangs.
2787
+ params = sorted(list(params), key=lambda x: self.buffer.param_to_param_group[x])
2785
2788
  for param in params:
2786
2789
  bucket_id = self.buffer.param_to_param_group[param]
2787
2790
  param_group = self.buffer.parameter_groups[bucket_id]
@@ -25,6 +25,8 @@ from torch.distributed.checkpoint.metadata import (
25
25
  from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
26
26
  from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
27
27
 
28
+ from .utils import get_mesh_names
29
+
28
30
 
29
31
  def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
30
32
  """
@@ -272,7 +274,14 @@ def gather_uneven_dtensor_to_full_tensor(
272
274
  if not device_mesh.mesh_dim_names:
273
275
  process_group = device_mesh.get_group()
274
276
  else:
275
- process_group = device_mesh._flatten().get_group()
277
+ # Check if the fully-flattened mesh exists first.
278
+ full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
279
+ if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
280
+ # Retrieve the existing flattened DeviceMesh ProcessGroup.
281
+ process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
282
+ else:
283
+ # Create the _-separated flattened DeviceMesh ProcessGroup.
284
+ process_group = device_mesh._flatten().get_group()
276
285
 
277
286
  # Collect chunk metadata for uneven shards (update if missing)
278
287
  if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
@@ -167,13 +167,10 @@ def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
167
167
  submesh_dim_name
168
168
  for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
169
169
  for submesh_dim_name in (child_mesh.mesh_dim_names or [])
170
- if root_mesh == device_mesh
170
+ # Add flattened or other unaccounted for children of the root mesh.
171
+ if root_mesh == device_mesh and submesh_dim_name not in mesh_dim_names
171
172
  ]
172
- # Combine without duplicate dimensions.
173
- for dim_name in submesh_dim_names:
174
- if dim_name not in mesh_dim_names:
175
- mesh_dim_names.append(dim_name)
176
- return mesh_dim_names
173
+ return mesh_dim_names + submesh_dim_names
177
174
 
178
175
 
179
176
  def contains_submesh(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: megatron-fsdp
3
- Version: 0.2.0.dev112119
3
+ Version: 0.2.0.dev112613
4
4
  Summary: **Megatron-FSDP** is an NVIDIA-developed PyTorch extension that provides a high-performance implementation of Fully Sharded Data Parallelism (FSDP)
5
5
  Author-email: NVIDIA <nemo-toolkit@nvidia.com>
6
6
  Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>