rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- rslearn/arg_parser.py +59 -0
- rslearn/data_sources/copernicus.py +4 -4
- rslearn/data_sources/earthdaily.py +21 -1
- rslearn/data_sources/gcp_public_data.py +3 -3
- rslearn/data_sources/utils.py +1 -17
- rslearn/main.py +10 -1
- rslearn/models/trunk.py +0 -144
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +319 -0
- rslearn/train/callbacks/gradients.py +54 -34
- rslearn/train/data_module.py +70 -41
- rslearn/train/dataset.py +232 -54
- rslearn/train/lightning_module.py +4 -0
- rslearn/train/prediction_writer.py +7 -0
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/per_pixel_regression.py +259 -0
- rslearn/train/tasks/regression.py +6 -4
- rslearn/train/tasks/segmentation.py +44 -14
- rslearn/train/transforms/mask.py +69 -0
- rslearn/utils/geometry.py +8 -8
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/METADATA +3 -3
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/RECORD +26 -24
- rslearn/models/moe/distributed.py +0 -262
- rslearn/models/moe/soft.py +0 -676
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/WHEEL +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/top_level.txt +0 -0
{rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/RECORD

```diff
@@ -1,7 +1,8 @@
 rslearn/__init__.py,sha256=fFmAen3vxZyosEfPbG0W46IttujYGVxzrGkJ0YutmmY,73
+rslearn/arg_parser.py,sha256=mkiZCiomUI5GNjG1jfPuTJebGHFzXbyUqe0pPwS4lTA,2055
 rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
 rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
-rslearn/main.py,sha256=
+rslearn/main.py,sha256=lAMcE4e3wCO2tVUq3bJl2oOHyztsTagtSNc0kJU7OZk,29266
 rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rslearn/config/__init__.py,sha256=Bhf2VVncdMYRC8Wfb4GsJJ13OAJYNCO_ODLSNTmBOHM,638
 rslearn/config/dataset.py,sha256=cR6Jd9ppzHgKHCteUsNapCcsJk4k5X90E71EHfbW7m0,21046
@@ -10,11 +11,11 @@ rslearn/data_sources/aws_landsat.py,sha256=GA9H04KagBDm-N37jFdh_aHCX2ZneVdnqT1SN
 rslearn/data_sources/aws_open_data.py,sha256=nU_D5cqc-wibxq4uyUNb0z-XD0Puf1gZ8v5FMiMAN5w,30258
 rslearn/data_sources/aws_sentinel1.py,sha256=cmf_ZcB7GCyFAdbwExeAwJIHqLL0JVoXtq5WcQ8UuiU,5197
 rslearn/data_sources/climate_data_store.py,sha256=Hct-0Ui-_CCQISOlzsqkK1dKz8684HRqfVUI-zXW2wA,11571
-rslearn/data_sources/copernicus.py,sha256=
+rslearn/data_sources/copernicus.py,sha256=BHPyeDLeCy1-Sjyhv84snW-TnZyNVnLyn4pjjZLfTzE,36652
 rslearn/data_sources/data_source.py,sha256=69ptYhqa6pnKM04ux9hWvTPExN_lFNuU_0t_seYFnHE,3916
-rslearn/data_sources/earthdaily.py,sha256=
+rslearn/data_sources/earthdaily.py,sha256=dxOWm7Yiuh4fWVptRws_Ljh-HuNs1frf86ao91yS_80,19059
 rslearn/data_sources/earthdata_srtm.py,sha256=ysyVbVDLjhhLKdh7WKhQcwZezqvmTYaiPetTborW6zQ,11166
-rslearn/data_sources/gcp_public_data.py,sha256=
+rslearn/data_sources/gcp_public_data.py,sha256=kr9stYo7ZCvz8s4E3wmoY-yAGZoLa_9RCwjS-Q5k9dM,36128
 rslearn/data_sources/geotiff.py,sha256=sFUp919chaX4j6lQytNp__xnMLlDI3Ac3rfB6F8sgZ0,45
 rslearn/data_sources/google_earth_engine.py,sha256=hpkt74ly2lEwjRrDp8FBmGvB3MEw_mQ38Av4rQOR3_w,24246
 rslearn/data_sources/local_files.py,sha256=-XyydSPtui1m49YuP7YrNKjM5DBWMf7YgpWE9uRcvrM,18365
@@ -25,7 +26,7 @@ rslearn/data_sources/planetary_computer.py,sha256=Wchr-OmAffuVteUW6VRofIqFpE-cJq
 rslearn/data_sources/raster_source.py,sha256=b8wo55GhVLxXwx1WYLzeRAlzD_ZkE_P9tnvUOdnsfQE,689
 rslearn/data_sources/usda_cdl.py,sha256=2_V11AhPRgLEGd4U5Pmx3UvE2HWBPbsFXhUIQVRVFeE,7138
 rslearn/data_sources/usgs_landsat.py,sha256=31GmOUfmxwTE6MTiVI4psb-ciVmunuA8cfvqDuvTHPE,19312
-rslearn/data_sources/utils.py,sha256=
+rslearn/data_sources/utils.py,sha256=oi2ybE423TLgpXlNjZ5qDQxDiwbSs7b-qD71UueQZHE,11327
 rslearn/data_sources/vector_source.py,sha256=NCa7CxIrGKe9yRT0NyyFKFQboDGDZ1h7663PV9OfMOM,44
 rslearn/data_sources/worldcereal.py,sha256=Psdf3EF3REj1WDltHWyMaICY3--KAJO_nEqpF0Gl_G8,21808
 rslearn/data_sources/worldcover.py,sha256=rimHJpQN9a56GaYxyHTOGXKzE3bkWKgd1UbH5A4aaGs,6097
@@ -59,7 +60,7 @@ rslearn/models/ssl4eo_s12.py,sha256=sOGEHcDo-rNdmEyoLu2AVEqfxRM_cv6zpfAmyn5c6tw,
 rslearn/models/swin.py,sha256=bMlGePXMFou4A_YSUZzjHgN9NniGXaCWdGQ31xHDKis,5511
 rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
 rslearn/models/terramind.py,sha256=kipar8sMaHJJ3b8vIgL0-s4qhHcA0Vb854vmlZ9cWh4,7524
-rslearn/models/trunk.py,sha256=
+rslearn/models/trunk.py,sha256=H1QPQGAKsmocq3OiF66GW8MQI4LffupTDrgZR4Ta7QM,4708
 rslearn/models/unet.py,sha256=0xoKSsfG7y71lOqlx1F2G1H-4qq_ChjAuaAhNlTWIeo,5793
 rslearn/models/upsample.py,sha256=A0ppAFvoqSMMvESE5vxvA8giY8cToD8QoeMMPGk2tUg,965
 rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
@@ -70,33 +71,34 @@ rslearn/models/detr/matcher.py,sha256=4h_xFlgTMEJvJ6aLZUamrKZ72L5hDk9wPglNZ81JBg
 rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJTC608TaX6yY,3869
 rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
 rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
-rslearn/models/moe/distributed.py,sha256=CFlL8eC6I4LZynz5ydcs1Xy7BuLVuEfcqHupoiKLRQ0,7948
-rslearn/models/moe/soft.py,sha256=PJgifOWBvb-ltA0NJghsOtl3fDoixOz08ZQLqtEdndU,21411
 rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_Qk,2050
 rslearn/tile_stores/default.py,sha256=DEZmji2iLEVgI3abXwne6tb4C1qVtx_CaLxTpswfUV4,13852
 rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
 rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
-rslearn/train/data_module.py,sha256=
-rslearn/train/dataset.py,sha256=
-rslearn/train/lightning_module.py,sha256=
+rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
+rslearn/train/dataset.py,sha256=YiskNlYYcKqZxyw0Xzop1RGLbjMc-oK_rmhrSMVbTQg,51857
+rslearn/train/lightning_module.py,sha256=ge2z8trU7cMvxBeqUXC1tB44pftzitw7DRsIa6asBS4,14623
 rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
-rslearn/train/prediction_writer.py,sha256=
-rslearn/train/scheduler.py,sha256=
+rslearn/train/prediction_writer.py,sha256=YNs92QqPrqbREZXoE-aPa_oKQW0C9LvZAY129vyvI08,13288
+rslearn/train/scheduler.py,sha256=wFbmycMHgL6nRYeYalDjb0G8YVo8VD3T3sABS61jJ7c,2318
 rslearn/train/callbacks/__init__.py,sha256=VNV0ArZyYMvl3dGK2wl6F046khYJ1dEBlJS6G_SYNm0,47
-rslearn/train/callbacks/
-rslearn/train/callbacks/
+rslearn/train/callbacks/adapters.py,sha256=yfv8nyCj3jmo2_dNkFrjukKxh0MHsf2xKqWwMF0QUtY,1869
+rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjTZZEPBQaSul6Zc,17418
+rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
 rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
 rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
 rslearn/train/tasks/classification.py,sha256=DI0_Wzs-9rNPWokvfxi1BIA6QyqNee42SpptQx82WHM,13182
 rslearn/train/tasks/detection.py,sha256=OoZzC8ZbmhyZ30tD-4cB-3Jj0AN6Y7hg0wk27rDguCE,22297
 rslearn/train/tasks/multi_task.py,sha256=dBWsnbvQ0CReNsbDHmZ_-sXjUE0H4S2OPcbJwMquG9g,6016
-rslearn/train/tasks/
-rslearn/train/tasks/
+rslearn/train/tasks/per_pixel_regression.py,sha256=tkVntKFzPlWFxdupPlMfhIRWlJ0UCgxg_FGhcA2-wjE,8649
+rslearn/train/tasks/regression.py,sha256=_PoxOfWNseujD4IWsuTL82fAAXgtco4WdfkNXQ68Nbg,11497
+rslearn/train/tasks/segmentation.py,sha256=xEni3CLDyetviv84XrpJg5xeJU87WHGFKTVfIeemGIY,21868
 rslearn/train/tasks/task.py,sha256=4w2xKL_U5JAtdj2dYoVv82h6xTtgUsA3IvIOcXyZecs,3887
 rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
 rslearn/train/transforms/concatenate.py,sha256=sdVLJIyr9Nj2tzXEzvWFQnjJjyRSuhR_Faf6UlMIvbg,1568
 rslearn/train/transforms/crop.py,sha256=4jA3JJsC0ghicPHbfsNJ0d3WpChyvftY73ONiwQaif0,4214
 rslearn/train/transforms/flip.py,sha256=lkTeje3T8gNn2gt6957morXq1fGNho-apSpCvNp0_9o,3480
+rslearn/train/transforms/mask.py,sha256=pwt33XXWLwldLiar-PgVgBQzQd1qfL18SPz3LYQMoYM,2111
 rslearn/train/transforms/normalize.py,sha256=zYLqcDQcrjukjf5XrbFmS990PK1WQMSmHqQZKa_T040,3514
 rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
 rslearn/train/transforms/transform.py,sha256=8Q-dPrmDr0tJ9ZOwjWqWK8kbnKi4uLxEnS9Nwf6BVJk,3594
@@ -104,7 +106,7 @@ rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
 rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
 rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
 rslearn/utils/fsspec.py,sha256=9QwN46heBhjUnth3qFeRNE3W6Wlr6dM3twYVswPnS9o,5300
-rslearn/utils/geometry.py,sha256=
+rslearn/utils/geometry.py,sha256=kE4RP1g2QLcbUVF329CUSBUYHdPbpnnMuzSqLMOXvLQ,15955
 rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF9Ps,3599
 rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
 rslearn/utils/jsonargparse.py,sha256=JcTKQoZ6jgwag-kSeTIEVBO9AsRj0X1oEJBsoaCazH4,658
@@ -115,9 +117,9 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
 rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
 rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
 rslearn/utils/vector_format.py,sha256=XggLCIUQBZWhOXWjvhrxBOHULpsXCbktm804DSAAink,15167
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
+rslearn-0.0.4.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+rslearn-0.0.4.dist-info/METADATA,sha256=_UtS9N1YTE0lKlPdei2nlGpuEHFgLiqW1h0P50EfNOU,31534
+rslearn-0.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+rslearn-0.0.4.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+rslearn-0.0.4.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+rslearn-0.0.4.dist-info/RECORD,,
```
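Each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the unpadded urlsafe-base64 SHA-256 of the file and the size is in bytes; the RECORD file lists itself with both fields empty. As a minimal illustration (the `verify_record_row` helper below is hypothetical, not part of rslearn or the wheel tooling), one row can be checked against an unpacked wheel like this:

```python
import base64
import hashlib
from pathlib import Path


def verify_record_row(wheel_root: Path, row: str) -> bool:
    """Check one RECORD row of the form path,sha256=<digest>,<size>."""
    path, hash_field, size = row.rsplit(",", 2)
    if not hash_field:
        # The RECORD file lists itself with empty hash and size fields.
        return True
    algo, _, expected = hash_field.partition("=")
    data = (wheel_root / path).read_bytes()
    # Wheel RECORD digests are urlsafe base64 with trailing "=" padding stripped.
    actual = base64.urlsafe_b64encode(hashlib.new(algo, data).digest()).rstrip(b"=")
    return actual.decode() == expected and len(data) == int(size)


# Example: rslearn/py.typed is empty, so its recorded digest is the
# base64 SHA-256 of zero bytes and its recorded size is 0.
# verify_record_row(Path("unpacked-wheel"),
#     "rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0")
```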
rslearn/models/moe/distributed.py (deleted)

```diff
@@ -1,262 +0,0 @@
-"""Distributed training utilities for Soft MoE.
-
-This module provides utilities for distributed training of Soft MoE models,
-including all-gather operations and rank-based tensor splitting.
-
-Copied from
-https://raw.githubusercontent.com/lucidrains/soft-moe-pytorch/refs/heads/main/soft_moe_pytorch/distributed.py.
-"""
-
-from typing import Any
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Tensor, nn
-from torch.autograd import Function
-
-
-def exists(val: Any) -> bool:
-    """Check if a value exists (is not None).
-
-    Args:
-        val: The value to check.
-
-    Returns:
-        bool: True if the value is not None, False otherwise.
-    """
-    return val is not None
-
-
-def default(val: Any, d: Any) -> Any:
-    """Return the value if it exists, otherwise return the default.
-
-    Args:
-        val: The value to check.
-        d: The default value to return if val is None.
-
-    Returns:
-        Any: The value if it exists, otherwise the default.
-    """
-    return val if exists(val) else d
-
-
-def divisible_by(num: int, den: int) -> bool:
-    """Check if a number is divisible by another.
-
-    Args:
-        num: The numerator.
-        den: The denominator.
-
-    Returns:
-        bool: True if num is divisible by den, False otherwise.
-    """
-    return (num % den) == 0
-
-
-def pad_dim_to(t: Tensor, length: int, dim: int = 0) -> Tensor:
-    """Pad a tensor along a specific dimension to a target length.
-
-    Args:
-        t: The input tensor.
-        length: The target length to pad to.
-        dim: The dimension to pad along.
-
-    Returns:
-        Tensor: The padded tensor with the specified dimension padded to length.
-    """
-    pad_length = length - t.shape[dim]
-    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
-
-
-def all_gather_same_dim(t: Tensor) -> list[Tensor]:
-    """Gather tensors from all processes when they have the same dimension.
-
-    Args:
-        t: The tensor to gather from all processes.
-
-    Returns:
-        List[Tensor]: List of tensors gathered from all processes.
-
-    Note:
-        This function assumes all processes have tensors with the same shape.
-    """
-    world_size = dist.get_world_size()
-    t = t.contiguous()
-    gathered_tensors = [
-        torch.empty_like(t, device=t.device, dtype=t.dtype) for i in range(world_size)
-    ]
-    dist.all_gather(gathered_tensors, t)
-    return gathered_tensors
-
-
-def gather_sizes(t: Tensor, *, dim: int) -> Tensor:
-    """Gather the sizes of tensors along a specific dimension from all processes.
-
-    Args:
-        t: The input tensor.
-        dim: The dimension to gather sizes for.
-
-    Returns:
-        Tensor: Tensor containing the sizes from all processes.
-    """
-    size = torch.tensor(t.shape[dim], device=t.device, dtype=torch.long)
-    sizes = all_gather_same_dim(size)
-    return torch.stack(sizes)
-
-
-def has_only_one_value(t: Tensor) -> bool:
-    """Check if all values in a tensor are the same.
-
-    Args:
-        t: The input tensor.
-
-    Returns:
-        bool: True if all values in the tensor are identical, False otherwise.
-    """
-    return (t == t[0]).all()
-
-
-def all_gather_variable_dim(
-    t: Tensor, dim: int = 0, sizes: Tensor | None = None
-) -> tuple[Tensor, Tensor]:
-    """Gather tensors from all processes when they may have different dimensions.
-
-    Args:
-        t: The tensor to gather from all processes.
-        dim: The dimension along which tensors may vary.
-        sizes: Optional pre-computed sizes tensor. If None, will be computed.
-
-    Returns:
-        Tuple[Tensor, Tensor]: The gathered tensors and the sizes tensor.
-
-    Note:
-        This function handles the case where tensors from different processes
-        may have different sizes along the specified dimension.
-    """
-    device = t.device
-
-    if not exists(sizes):
-        sizes = gather_sizes(t, dim=dim)
-
-    if has_only_one_value(sizes):
-        gathered_tensors = all_gather_same_dim(t)
-        gathered_tensors = torch.cat(gathered_tensors, dim=dim)
-        return gathered_tensors, sizes
-
-    # Add null check for sizes
-    if sizes is None:
-        raise ValueError("sizes cannot be None")
-
-    max_size = sizes.amax().item()
-
-    padded_t = pad_dim_to(t, max_size, dim=dim)
-    gathered_tensors = all_gather_same_dim(padded_t)
-
-    gathered_tensors = torch.cat(gathered_tensors, dim=dim)
-    seq = torch.arange(max_size, device=device)
-
-    mask = rearrange(seq, "j -> 1 j") < rearrange(sizes, "i -> i 1")
-    mask = rearrange(mask, "i j -> (i j)")
-    seq = torch.arange(mask.shape[-1], device=device)
-    indices = seq[mask]
-
-    # Convert gathered_tensors to tensor before calling index_select
-    if isinstance(gathered_tensors, list):
-        gathered_tensors = torch.cat(gathered_tensors, dim=dim)
-
-    gathered_tensors = gathered_tensors.index_select(dim, indices)  # type: ignore
-
-    return gathered_tensors, sizes
-
-
-class AllGatherFunction(Function):
-    """Custom autograd function for all-gather operations.
-
-    This function provides gradient support for all-gather operations
-    by implementing custom forward and backward passes.
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any, x: Tensor, dim: int, sizes: Tensor | None
-    ) -> tuple[Tensor, Tensor]:
-        """Forward pass of the all-gather function.
-
-        Args:
-            ctx: The context object for storing information for backward pass.
-            x: The input tensor to gather.
-            dim: The dimension along which to gather.
-            sizes: Optional pre-computed sizes tensor.
-
-        Returns:
-            Tuple[Tensor, Tensor]: The gathered tensor and the sizes tensor.
-        """
-        x, batch_sizes = all_gather_variable_dim(x, dim=dim, sizes=sizes)
-        ctx.batch_sizes = batch_sizes.tolist()
-        ctx.dim = dim
-        return x, batch_sizes
-
-    @staticmethod
-    def backward(ctx: Any, grads: Tensor, _: Any) -> tuple[Tensor, None, None]:
-        """Backward pass of the all-gather function.
-
-        Args:
-            ctx: The context object containing information from forward pass.
-            grads: The gradient tensor.
-            _: Unused parameter for compatibility.
-
-        Returns:
-            Tuple[Tensor, None, None]: The gradient for the input tensor and None for other inputs.
-        """
-        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim=ctx.dim)
-        return grads_by_rank[rank], None, None
-
-
-class AllGather(nn.Module):
-    """A module that performs all-gather operations across distributed processes.
-
-    This module provides a convenient interface for gathering tensors from
-    all processes in a distributed training setup.
-    """
-
-    def __init__(self, *, dim: int = 0) -> None:
-        """Initialize the AllGather module.
-
-        Args:
-            dim: The dimension along which to gather tensors.
-        """
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x: Tensor, sizes: Tensor | None = None) -> tuple[Tensor, Tensor]:
-        """Forward pass of the all-gather operation.
-
-        Args:
-            x: The input tensor to gather from all processes.
-            sizes: Optional pre-computed sizes tensor.
-
-        Returns:
-            Tuple[Tensor, Tensor]: The gathered tensor and the sizes tensor.
-        """
-        return AllGatherFunction.apply(x, self.dim, sizes)
-
-
-def split_by_rank(x: list[Tensor]) -> Tensor:
-    """Split a list of tensors and return the tensor corresponding to the current rank.
-
-    Args:
-        x: List of tensors, one per rank.
-
-    Returns:
-        Tensor: The tensor corresponding to the current process rank.
-
-    Note:
-        This function assumes the list has one tensor per rank and returns
-        the tensor corresponding to the current process rank.
-    """
-    rank = dist.get_rank()
-    return x[rank]
```
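The deleted module's core trick is making the gather differentiable: `AllGatherFunction.forward` records each rank's size along the gather dimension, and `backward` splits the incoming gradient by those sizes and hands each rank only its own slice. A minimal usage sketch, assuming an already-initialized `torch.distributed` process group and the `AllGather` class shown above (the shapes are illustrative):

```python
import torch
import torch.distributed as dist

# Illustrative only: assumes dist.init_process_group(...) has run and that
# AllGather is the module from the removed distributed.py above.
gather = AllGather(dim=0)

# Ranks may hold different numbers of rows along dim 0 (here rank r has r + 1).
rank = dist.get_rank()
local = torch.randn(rank + 1, 8, requires_grad=True)

# gathered is every rank's rows concatenated along dim 0; sizes records the
# per-rank row counts that all_gather_variable_dim used to pad, gather, and
# then strip the padding.
gathered, sizes = gather(local)

# Backward routes each rank only the gradient slice for its own rows.
gathered.sum().backward()
assert local.grad is not None and local.grad.shape == local.shape
```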