rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
1
1
  rslearn/__init__.py,sha256=fFmAen3vxZyosEfPbG0W46IttujYGVxzrGkJ0YutmmY,73
2
+ rslearn/arg_parser.py,sha256=mkiZCiomUI5GNjG1jfPuTJebGHFzXbyUqe0pPwS4lTA,2055
2
3
  rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
3
4
  rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
4
- rslearn/main.py,sha256=vnWGvfNgj0mlUFwmOo3_OWMs2-FG8Q41LgivBLmLitA,28829
5
+ rslearn/main.py,sha256=lAMcE4e3wCO2tVUq3bJl2oOHyztsTagtSNc0kJU7OZk,29266
5
6
  rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
7
  rslearn/config/__init__.py,sha256=Bhf2VVncdMYRC8Wfb4GsJJ13OAJYNCO_ODLSNTmBOHM,638
7
8
  rslearn/config/dataset.py,sha256=cR6Jd9ppzHgKHCteUsNapCcsJk4k5X90E71EHfbW7m0,21046
@@ -10,11 +11,11 @@ rslearn/data_sources/aws_landsat.py,sha256=GA9H04KagBDm-N37jFdh_aHCX2ZneVdnqT1SN
10
11
  rslearn/data_sources/aws_open_data.py,sha256=nU_D5cqc-wibxq4uyUNb0z-XD0Puf1gZ8v5FMiMAN5w,30258
11
12
  rslearn/data_sources/aws_sentinel1.py,sha256=cmf_ZcB7GCyFAdbwExeAwJIHqLL0JVoXtq5WcQ8UuiU,5197
12
13
  rslearn/data_sources/climate_data_store.py,sha256=Hct-0Ui-_CCQISOlzsqkK1dKz8684HRqfVUI-zXW2wA,11571
13
- rslearn/data_sources/copernicus.py,sha256=fIQXVDDfJwklgna6PPeCDOE6X6KSqw_XwaKcIbxJilI,36660
14
+ rslearn/data_sources/copernicus.py,sha256=BHPyeDLeCy1-Sjyhv84snW-TnZyNVnLyn4pjjZLfTzE,36652
14
15
  rslearn/data_sources/data_source.py,sha256=69ptYhqa6pnKM04ux9hWvTPExN_lFNuU_0t_seYFnHE,3916
15
- rslearn/data_sources/earthdaily.py,sha256=PDsbE47mhbzzeEGiXNjZlNt-qKHYGzAjJbZTIs-Halk,18110
16
+ rslearn/data_sources/earthdaily.py,sha256=dxOWm7Yiuh4fWVptRws_Ljh-HuNs1frf86ao91yS_80,19059
16
17
  rslearn/data_sources/earthdata_srtm.py,sha256=ysyVbVDLjhhLKdh7WKhQcwZezqvmTYaiPetTborW6zQ,11166
17
- rslearn/data_sources/gcp_public_data.py,sha256=O8v2iN4ym9Kwl4Zlw1FURPWnSoMD1drje8XH4xNQggE,36134
18
+ rslearn/data_sources/gcp_public_data.py,sha256=kr9stYo7ZCvz8s4E3wmoY-yAGZoLa_9RCwjS-Q5k9dM,36128
18
19
  rslearn/data_sources/geotiff.py,sha256=sFUp919chaX4j6lQytNp__xnMLlDI3Ac3rfB6F8sgZ0,45
19
20
  rslearn/data_sources/google_earth_engine.py,sha256=hpkt74ly2lEwjRrDp8FBmGvB3MEw_mQ38Av4rQOR3_w,24246
20
21
  rslearn/data_sources/local_files.py,sha256=-XyydSPtui1m49YuP7YrNKjM5DBWMf7YgpWE9uRcvrM,18365
@@ -25,7 +26,7 @@ rslearn/data_sources/planetary_computer.py,sha256=Wchr-OmAffuVteUW6VRofIqFpE-cJq
25
26
  rslearn/data_sources/raster_source.py,sha256=b8wo55GhVLxXwx1WYLzeRAlzD_ZkE_P9tnvUOdnsfQE,689
26
27
  rslearn/data_sources/usda_cdl.py,sha256=2_V11AhPRgLEGd4U5Pmx3UvE2HWBPbsFXhUIQVRVFeE,7138
27
28
  rslearn/data_sources/usgs_landsat.py,sha256=31GmOUfmxwTE6MTiVI4psb-ciVmunuA8cfvqDuvTHPE,19312
28
- rslearn/data_sources/utils.py,sha256=xeLQeUh--fjnfJuyC8nZPdRxMIQdnn6VLoFlMQ32hPE,12114
29
+ rslearn/data_sources/utils.py,sha256=oi2ybE423TLgpXlNjZ5qDQxDiwbSs7b-qD71UueQZHE,11327
29
30
  rslearn/data_sources/vector_source.py,sha256=NCa7CxIrGKe9yRT0NyyFKFQboDGDZ1h7663PV9OfMOM,44
30
31
  rslearn/data_sources/worldcereal.py,sha256=Psdf3EF3REj1WDltHWyMaICY3--KAJO_nEqpF0Gl_G8,21808
31
32
  rslearn/data_sources/worldcover.py,sha256=rimHJpQN9a56GaYxyHTOGXKzE3bkWKgd1UbH5A4aaGs,6097
@@ -59,7 +60,7 @@ rslearn/models/ssl4eo_s12.py,sha256=sOGEHcDo-rNdmEyoLu2AVEqfxRM_cv6zpfAmyn5c6tw,
59
60
  rslearn/models/swin.py,sha256=bMlGePXMFou4A_YSUZzjHgN9NniGXaCWdGQ31xHDKis,5511
60
61
  rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
61
62
  rslearn/models/terramind.py,sha256=kipar8sMaHJJ3b8vIgL0-s4qhHcA0Vb854vmlZ9cWh4,7524
62
- rslearn/models/trunk.py,sha256=jm5kXlHdFkUOuIlOSwfko93Luzefkug19hKYx5l617Y,10706
63
+ rslearn/models/trunk.py,sha256=H1QPQGAKsmocq3OiF66GW8MQI4LffupTDrgZR4Ta7QM,4708
63
64
  rslearn/models/unet.py,sha256=0xoKSsfG7y71lOqlx1F2G1H-4qq_ChjAuaAhNlTWIeo,5793
64
65
  rslearn/models/upsample.py,sha256=A0ppAFvoqSMMvESE5vxvA8giY8cToD8QoeMMPGk2tUg,965
65
66
  rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
@@ -70,33 +71,34 @@ rslearn/models/detr/matcher.py,sha256=4h_xFlgTMEJvJ6aLZUamrKZ72L5hDk9wPglNZ81JBg
70
71
  rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJTC608TaX6yY,3869
71
72
  rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
72
73
  rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
73
- rslearn/models/moe/distributed.py,sha256=CFlL8eC6I4LZynz5ydcs1Xy7BuLVuEfcqHupoiKLRQ0,7948
74
- rslearn/models/moe/soft.py,sha256=PJgifOWBvb-ltA0NJghsOtl3fDoixOz08ZQLqtEdndU,21411
75
74
  rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_Qk,2050
76
75
  rslearn/tile_stores/default.py,sha256=DEZmji2iLEVgI3abXwne6tb4C1qVtx_CaLxTpswfUV4,13852
77
76
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
78
77
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
79
- rslearn/train/data_module.py,sha256=bmwMNmpNCD_6mUHoS8TGJ_6ogPD79YSaN6CWYG1Cu90,22028
80
- rslearn/train/dataset.py,sha256=K9ncH980rvKl08PLzzXNCd4PxZRvhSu_r7SfI0R5kAI,45217
81
- rslearn/train/lightning_module.py,sha256=5wzXsXf_N3r8s_qbgxZwiDj-UWkRgviuVrmbdzPZDvg,14397
78
+ rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
79
+ rslearn/train/dataset.py,sha256=YiskNlYYcKqZxyw0Xzop1RGLbjMc-oK_rmhrSMVbTQg,51857
80
+ rslearn/train/lightning_module.py,sha256=ge2z8trU7cMvxBeqUXC1tB44pftzitw7DRsIa6asBS4,14623
82
81
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
83
- rslearn/train/prediction_writer.py,sha256=jiJChKmP6ZWylS3ElRVlfYBNANwYkKFp_wmTHXf5OTA,13012
84
- rslearn/train/scheduler.py,sha256=MBMv3TEtjEJuHZGNtP_qXKk9UenQg-JKdmbkHEV3jsc,1850
82
+ rslearn/train/prediction_writer.py,sha256=YNs92QqPrqbREZXoE-aPa_oKQW0C9LvZAY129vyvI08,13288
83
+ rslearn/train/scheduler.py,sha256=wFbmycMHgL6nRYeYalDjb0G8YVo8VD3T3sABS61jJ7c,2318
85
84
  rslearn/train/callbacks/__init__.py,sha256=VNV0ArZyYMvl3dGK2wl6F046khYJ1dEBlJS6G_SYNm0,47
86
- rslearn/train/callbacks/freeze_unfreeze.py,sha256=vuNLGyTLxyhR2Ih-GEoH-CDwFOBwBgm3yWEpn9rusBU,3766
87
- rslearn/train/callbacks/gradients.py,sha256=nhbA5f4QUsVfa-oV9Zj2xKk7GwWu5ejplhtIZEkRDng,4390
85
+ rslearn/train/callbacks/adapters.py,sha256=yfv8nyCj3jmo2_dNkFrjukKxh0MHsf2xKqWwMF0QUtY,1869
86
+ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjTZZEPBQaSul6Zc,17418
87
+ rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
88
88
  rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
89
89
  rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
90
90
  rslearn/train/tasks/classification.py,sha256=DI0_Wzs-9rNPWokvfxi1BIA6QyqNee42SpptQx82WHM,13182
91
91
  rslearn/train/tasks/detection.py,sha256=OoZzC8ZbmhyZ30tD-4cB-3Jj0AN6Y7hg0wk27rDguCE,22297
92
92
  rslearn/train/tasks/multi_task.py,sha256=dBWsnbvQ0CReNsbDHmZ_-sXjUE0H4S2OPcbJwMquG9g,6016
93
- rslearn/train/tasks/regression.py,sha256=mnlKT6zdf1fFfuOXDIAzGFZs_uU1szIoYsGuZcJJflQ,11444
94
- rslearn/train/tasks/segmentation.py,sha256=-CIMLYseHmuNkzrVtahILD4iMtdtN_fvW2-NcfiGq0U,20381
93
+ rslearn/train/tasks/per_pixel_regression.py,sha256=tkVntKFzPlWFxdupPlMfhIRWlJ0UCgxg_FGhcA2-wjE,8649
94
+ rslearn/train/tasks/regression.py,sha256=_PoxOfWNseujD4IWsuTL82fAAXgtco4WdfkNXQ68Nbg,11497
95
+ rslearn/train/tasks/segmentation.py,sha256=xEni3CLDyetviv84XrpJg5xeJU87WHGFKTVfIeemGIY,21868
95
96
  rslearn/train/tasks/task.py,sha256=4w2xKL_U5JAtdj2dYoVv82h6xTtgUsA3IvIOcXyZecs,3887
96
97
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
97
98
  rslearn/train/transforms/concatenate.py,sha256=sdVLJIyr9Nj2tzXEzvWFQnjJjyRSuhR_Faf6UlMIvbg,1568
98
99
  rslearn/train/transforms/crop.py,sha256=4jA3JJsC0ghicPHbfsNJ0d3WpChyvftY73ONiwQaif0,4214
99
100
  rslearn/train/transforms/flip.py,sha256=lkTeje3T8gNn2gt6957morXq1fGNho-apSpCvNp0_9o,3480
101
+ rslearn/train/transforms/mask.py,sha256=pwt33XXWLwldLiar-PgVgBQzQd1qfL18SPz3LYQMoYM,2111
100
102
  rslearn/train/transforms/normalize.py,sha256=zYLqcDQcrjukjf5XrbFmS990PK1WQMSmHqQZKa_T040,3514
101
103
  rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
102
104
  rslearn/train/transforms/transform.py,sha256=8Q-dPrmDr0tJ9ZOwjWqWK8kbnKi4uLxEnS9Nwf6BVJk,3594
@@ -104,7 +106,7 @@ rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
104
106
  rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
105
107
  rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
106
108
  rslearn/utils/fsspec.py,sha256=9QwN46heBhjUnth3qFeRNE3W6Wlr6dM3twYVswPnS9o,5300
107
- rslearn/utils/geometry.py,sha256=DTmDXNKsAaOOtm8laWELcYL6ckxBbgFJddeyGE0OMiw,15971
109
+ rslearn/utils/geometry.py,sha256=kE4RP1g2QLcbUVF329CUSBUYHdPbpnnMuzSqLMOXvLQ,15955
108
110
  rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF9Ps,3599
109
111
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
110
112
  rslearn/utils/jsonargparse.py,sha256=JcTKQoZ6jgwag-kSeTIEVBO9AsRj0X1oEJBsoaCazH4,658
@@ -115,9 +117,9 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
115
117
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
116
118
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
117
119
  rslearn/utils/vector_format.py,sha256=XggLCIUQBZWhOXWjvhrxBOHULpsXCbktm804DSAAink,15167
118
- rslearn-0.0.3.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
119
- rslearn-0.0.3.dist-info/METADATA,sha256=5aejNuiFYOCf7CNX-8TDnSSCd6f_Mddw1JO2SKN1cVE,31707
120
- rslearn-0.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
121
- rslearn-0.0.3.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
122
- rslearn-0.0.3.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
123
- rslearn-0.0.3.dist-info/RECORD,,
120
+ rslearn-0.0.4.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
121
+ rslearn-0.0.4.dist-info/METADATA,sha256=_UtS9N1YTE0lKlPdei2nlGpuEHFgLiqW1h0P50EfNOU,31534
122
+ rslearn-0.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
123
+ rslearn-0.0.4.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
124
+ rslearn-0.0.4.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
125
+ rslearn-0.0.4.dist-info/RECORD,,
@@ -1,262 +0,0 @@
1
- """Distributed training utilities for Soft MoE.
2
-
3
- This module provides utilities for distributed training of Soft MoE models,
4
- including all-gather operations and rank-based tensor splitting.
5
-
6
- Copied from
7
- https://raw.githubusercontent.com/lucidrains/soft-moe-pytorch/refs/heads/main/soft_moe_pytorch/distributed.py.
8
- """
9
-
10
- from typing import Any
11
-
12
- import torch
13
- import torch.distributed as dist
14
- import torch.nn.functional as F
15
- from einops import rearrange
16
- from torch import Tensor, nn
17
- from torch.autograd import Function
18
-
19
-
20
def exists(val: Any) -> bool:
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Args:
        val: Object to inspect.

    Returns:
        bool: ``False`` only when ``val`` is ``None``; ``True`` for every other
        value, including falsy ones such as ``0`` or ``""``.
    """
    if val is None:
        return False
    return True
30
-
31
-
32
def default(val: Any, d: Any) -> Any:
    """Fall back to ``d`` when ``val`` is ``None``.

    Args:
        val: Candidate value.
        d: Replacement used only when ``val`` is ``None``.

    Returns:
        Any: ``val`` itself unless it is ``None``, in which case ``d``. Falsy
        but non-``None`` values (``0``, ``""``) are returned unchanged.
    """
    if val is None:
        return d
    return val
43
-
44
-
45
def divisible_by(num: int, den: int) -> bool:
    """Tell whether ``den`` divides ``num`` with no remainder.

    Args:
        num: Dividend.
        den: Divisor (a zero divisor raises ``ZeroDivisionError`` from ``%``).

    Returns:
        bool: ``True`` when ``num % den`` is zero.
    """
    remainder = num % den
    return remainder == 0
56
-
57
-
58
def pad_dim_to(t: Tensor, length: int, dim: int = 0) -> Tensor:
    """Zero-pad ``t`` at the end of one dimension so that it reaches ``length``.

    Args:
        t: Tensor to pad.
        length: Desired size of dimension ``dim`` after padding.
        dim: Dimension to extend; negative values count from the end.

    Returns:
        Tensor: ``t`` with ``length - t.shape[dim]`` zeros appended along ``dim``.
        If ``length`` is smaller than the current size, ``F.pad`` receives a
        negative amount and trims the tensor instead.
    """
    missing = length - t.shape[dim]
    # F.pad reads (before, after) pairs starting from the LAST dimension, so
    # every dimension after `dim` needs a (0, 0) placeholder pair first.
    if dim < 0:
        trailing_dims = -dim - 1
    else:
        trailing_dims = t.ndim - dim - 1
    spec = [0, 0] * trailing_dims + [0, missing]
    return F.pad(t, tuple(spec))
72
-
73
-
74
def all_gather_same_dim(t: Tensor) -> list[Tensor]:
    """Collect ``t`` from every rank of the default process group.

    Args:
        t: Local tensor to share.

    Returns:
        List[Tensor]: One tensor per rank, in rank order, as filled in by
        ``dist.all_gather``.

    Note:
        Each receive buffer is allocated with the local tensor's shape, so every
        rank must contribute a tensor of identical shape.
    """
    contiguous = t.contiguous()
    # empty_like already inherits device and dtype from its input.
    buffers = []
    for _ in range(dist.get_world_size()):
        buffers.append(torch.empty_like(contiguous))
    dist.all_gather(buffers, contiguous)
    return buffers
93
-
94
-
95
def gather_sizes(t: Tensor, *, dim: int) -> Tensor:
    """Share every rank's size along ``dim`` with all other ranks.

    Args:
        t: Local tensor whose extent along ``dim`` is reported.
        dim: Dimension whose size is exchanged.

    Returns:
        Tensor: 1-D ``torch.long`` tensor of per-rank sizes, in rank order.
    """
    local_size = torch.tensor(t.shape[dim], device=t.device, dtype=torch.long)
    return torch.stack(all_gather_same_dim(local_size))
108
-
109
-
110
def has_only_one_value(t: Tensor) -> bool:
    """Check if all values in a tensor are the same.

    The tensor is flattened first so that every element is compared, not just
    leading-dimension slices. An empty tensor trivially satisfies the check.

    Args:
        t: The input tensor (any shape, including 0-d).

    Returns:
        bool: True if all elements are identical, False otherwise.
    """
    flat = t.reshape(-1)
    if flat.numel() == 0:
        return True
    # Convert to a Python bool so the annotated return type is honoured
    # instead of leaking a 0-d tensor to callers.
    return bool((flat == flat[0]).all())
120
-
121
-
122
def all_gather_variable_dim(
    t: Tensor, dim: int = 0, sizes: Tensor | None = None
) -> tuple[Tensor, Tensor]:
    """Gather tensors from all processes when they may have different dimensions.

    Each rank pads its tensor along ``dim`` to the largest size and gathers
    the padded tensors. The padding is then removed, so the result is the
    rank-ordered concatenation of the original (unpadded) tensors.

    Args:
        t: The tensor to gather from all processes.
        dim: The dimension along which tensors may vary.
        sizes: Optional pre-computed 1-D tensor of per-rank sizes along ``dim``.
            If None, it is gathered from all ranks. NOTE(review): a caller-supplied
            value is trusted as-is and is not validated against the actual shapes.

    Returns:
        Tuple[Tensor, Tensor]: The concatenated tensor and the sizes tensor
        (either gathered here or the one passed in).
    """
    device = t.device

    if sizes is None:
        sizes = gather_sizes(t, dim=dim)

    # Fast path: every rank has the same extent, so no padding is needed.
    if has_only_one_value(sizes):
        return torch.cat(all_gather_same_dim(t), dim=dim), sizes

    max_size = int(sizes.amax().item())
    padded_t = pad_dim_to(t, max_size, dim=dim)
    gathered = torch.cat(all_gather_same_dim(padded_t), dim=dim)

    # keep[i, j] is True when position j of rank i's padded block is real data.
    # Flattening row-major matches the rank-ordered layout of `gathered` along
    # `dim`, so the True positions are exactly the indices to retain.
    positions = torch.arange(max_size, device=device)
    keep = positions.unsqueeze(0) < sizes.to(device).unsqueeze(1)
    indices = torch.arange(keep.numel(), device=device)[keep.reshape(-1)]

    return gathered.index_select(dim, indices), sizes
173
-
174
-
175
class AllGatherFunction(Function):
    """Autograd function that all-gathers possibly ragged tensors.

    The forward pass concatenates every rank's tensor along ``dim`` via
    ``all_gather_variable_dim``. The backward pass returns to each rank only
    the slice of the incoming gradient that matches its own contribution.
    """

    @staticmethod
    def forward(
        ctx: Any, x: Tensor, dim: int, sizes: Tensor | None
    ) -> tuple[Tensor, Tensor]:
        """Gather ``x`` from all ranks and remember how to split gradients.

        Args:
            ctx: Autograd context used to pass state to ``backward``.
            x: Local tensor to gather.
            dim: Dimension along which rank tensors are concatenated.
            sizes: Optional pre-computed per-rank sizes along ``dim``.

        Returns:
            Tuple[Tensor, Tensor]: The concatenated tensor and the per-rank sizes.
        """
        gathered, per_rank_sizes = all_gather_variable_dim(x, dim=dim, sizes=sizes)
        # Tensor.split in backward takes a Python list of section sizes.
        ctx.dim = dim
        ctx.batch_sizes = per_rank_sizes.tolist()
        return gathered, per_rank_sizes

    @staticmethod
    def backward(ctx: Any, grads: Tensor, _: Any) -> tuple[Tensor, None, None]:
        """Select this rank's share of the gradient of the gathered output.

        Args:
            ctx: Autograd context populated by ``forward``.
            grads: Gradient with respect to the gathered tensor.
            _: Gradient with respect to the sizes output (ignored).

        Returns:
            Tuple[Tensor, None, None]: Gradient for ``x``; ``None`` for ``dim``
            and ``sizes``.
        """
        own_rank = dist.get_rank()
        own_grad = grads.split(ctx.batch_sizes, dim=ctx.dim)[own_rank]
        return own_grad, None, None
217
-
218
-
219
- class AllGather(nn.Module):
220
- """A module that performs all-gather operations across distributed processes.
221
-
222
- This module provides a convenient interface for gathering tensors from
223
- all processes in a distributed training setup.
224
- """
225
-
226
- def __init__(self, *, dim: int = 0) -> None:
227
- """Initialize the AllGather module.
228
-
229
- Args:
230
- dim: The dimension along which to gather tensors.
231
- """
232
- super().__init__()
233
- self.dim = dim
234
-
235
- def forward(self, x: Tensor, sizes: Tensor | None = None) -> tuple[Tensor, Tensor]:
236
- """Forward pass of the all-gather operation.
237
-
238
- Args:
239
- x: The input tensor to gather from all processes.
240
- sizes: Optional pre-computed sizes tensor.
241
-
242
- Returns:
243
- Tuple[Tensor, Tensor]: The gathered tensor and the sizes tensor.
244
- """
245
- return AllGatherFunction.apply(x, self.dim, sizes)
246
-
247
-
248
def split_by_rank(x: list[Tensor]) -> Tensor:
    """Pick this process's entry from a per-rank list of tensors.

    Args:
        x: Sequence holding one tensor per rank, indexed by rank.

    Returns:
        Tensor: ``x[rank]``, where ``rank`` is this process's rank in the
        default process group.
    """
    return x[dist.get_rank()]