adv-optm 2.4.dev11__tar.gz → 2.4.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/PKG-INFO +1 -1
  2. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/__init__.py +3 -3
  3. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py +1 -1
  4. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py +16 -6
  5. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py +1 -1
  6. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py +1 -1
  7. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py +1 -1
  8. adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev13/adv_optm/optim/SinkSGD_adv.py +31 -22
  9. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py +2 -2
  10. adv_optm-2.4.dev13/adv_optm/util/sinkhorn.py +77 -0
  11. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO +1 -1
  12. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt +1 -1
  13. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/setup.py +1 -1
  14. adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py +0 -42
  15. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/LICENSE +0 -0
  16. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/README.md +0 -0
  17. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  18. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Lion_adv.py +0 -0
  19. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Prodigy_adv.py +0 -0
  20. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  21. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Kourkoutas.py +0 -0
  22. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Muon_AuxAdam.py +0 -0
  23. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/Muon_util.py +0 -0
  24. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/OrthoGrad.py +0 -0
  25. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/__init__.py +0 -0
  26. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/centered_decay.py +0 -0
  27. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/factorization_util.py +0 -0
  28. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/lion_k.py +0 -0
  29. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/param_update.py +0 -0
  30. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/scaled_optm.py +0 -0
  31. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/signed_util.py +0 -0
  32. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/state_util.py +0 -0
  33. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/util/update_util.py +0 -0
  34. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/dependency_links.txt +0 -0
  35. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/requires.txt +0 -0
  36. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/top_level.txt +0 -0
  37. {adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/setup.cfg +0 -0
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev11
+Version: 2.4.dev13
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/__init__.py
@@ -8,7 +8,7 @@ from .optim import (
     Muon_adv,
     AdaMuon_adv,
     SignSGD_adv,
-    SGD_adv,
+    SinkSGD_adv,
 )
 
 __all__ = [
@@ -21,7 +21,7 @@ __all__ = [
     "Muon_adv",
     "AdaMuon_adv",
     "SignSGD_adv",
-    "SGD_adv",
+    "SinkSGD_adv",
 ]
 
-__version__ = "2.4.dev11"
+__version__ = "2.4.dev13"
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdaMuon_adv.py
@@ -468,7 +468,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
         if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif actual_precision == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif actual_precision == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         else:
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/AdamW_adv.py
@@ -232,12 +232,13 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
         # State Initialization
@@ -303,6 +304,15 @@ class AdamW_adv(torch.optim.Optimizer):
 
         _init_fisher_wd_scaler(group, state, p)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         beta1, beta2 = group['betas']
 
         current_step = state['step']
@@ -333,7 +343,7 @@ class AdamW_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
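The AdamW_adv hunks above split lazy state creation out of step_parameter into an init_step()/__init_state() pair, with step_parameter still calling __init_state defensively. A condensed, standalone sketch of that control flow (the class body, state tensors, and update math are simplified placeholders for illustration, not the package's code):

    import torch

    class _InitStepSketch(torch.optim.Optimizer):
        # Illustration only: mirrors the init_step/__init_state/step_parameter
        # structure introduced in the hunks above.
        def init_step(self):
            # Eagerly allocate state for every parameter up front.
            for group in self.param_groups:
                for i, p in enumerate(group['params']):
                    self.__init_state(p, group)

        @torch.no_grad()
        def __init_state(self, p, group):
            state = self.state[p]
            if 'step' not in state:          # idempotent, so re-calling it per step is cheap
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)   # assumed placeholder state

        @torch.no_grad()
        def step_parameter(self, p, group, i=None):
            if p.grad is None:
                return
            self.__init_state(p, group)      # lazy path kept for safety
            # ...the real per-parameter update continues here...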
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Adopt_adv.py
@@ -359,7 +359,7 @@ class Adopt_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/Muon_adv.py
@@ -422,7 +422,7 @@ class Muon_adv(torch.optim.Optimizer):
         if actual_precision == 'bf16_sr' and random_int_state_tensor is not None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif actual_precision == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif actual_precision == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         else:
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/SignSGD_adv.py
@@ -252,7 +252,7 @@ class SignSGD_adv(torch.optim.Optimizer):
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
 
adv_optm-2.4.dev11/adv_optm/optim/SGD_adv.py → adv_optm-2.4.dev13/adv_optm/optim/SinkSGD_adv.py
@@ -11,12 +11,11 @@ from ..util.centered_decay import _init_anchor
 from ..util.state_util import init_state_tensor, get_state, set_state, upcast_grad_for_precision
 from ..util.sinkhorn import apply_sr_sinkhorn
 
-class SGD_adv(torch.optim.Optimizer):
+class SinkSGD_adv(torch.optim.Optimizer):
     """
-    Implements an advanced Stochastic Gradient Descent (SGD) algorithm.
-    This is an advanced version of SGD with optional features like
-    low-rank factorization of optimizer states (SMMF), OrthoGrad,
-    Cautious updating, and AdEMAMix extensions.
+    Implements an advanced Stochastic Gradient Descent (SGD) with Sinkhorn Iterative Normalization (SinkSGD) algorithm.
+    This is an advanced version of SinkSGD with optional features like
+    low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -62,11 +61,11 @@ class SGD_adv(torch.optim.Optimizer):
         cautious_wd: bool = False,
         # Stochastic Rounding for BF16
         stochastic_rounding: bool = True,
-        # OrthoGrad
-        orthogonal_gradient: bool = False,
         # Sinkhorn Iterative Normalization
-        sinkhorn: bool = False,
         sinkhorn_iterations: int = 5,
+        orthogonal_sinkhorn: bool = False,
+        # OrthoGrad
+        orthogonal_gradient: bool = False,
         # Spectral Normed Optimizer
         spectral_normalization: bool = False,
         # Centered WD
@@ -101,7 +100,8 @@ class SGD_adv(torch.optim.Optimizer):
             "decoupled_wd": decoupled_wd, "cautious_wd": cautious_wd,
             "orthogonal_gradient": orthogonal_gradient,
             "compiled_optimizer": compiled_optimizer,
-            "sinkhorn": sinkhorn, "sinkhorn_iterations": sinkhorn_iterations,
+            "sinkhorn_iterations": sinkhorn_iterations,
+            "orthogonal_sinkhorn": orthogonal_sinkhorn,
             "spectral_normalization": spectral_normalization,
             "centered_wd": centered_wd, "centered_wd_mode": centered_wd_mode,
             "state_precision": state_precision,
@@ -116,6 +116,8 @@
         for device in devices:
             param_update.set_seed(device)
 
+        self.init_step()
+
         self._compiled_step_parameter = None
         if compiled_optimizer:
             self.compile(fullgraph=True)
@@ -136,14 +138,14 @@ class SGD_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-    @torch.no_grad()
-    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
-        if p.grad is None:
-            return
+    def init_step(self):
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-        grad = p.grad
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
-
         # State Initialization
         if 'step' not in state:
             state['step'] = 0
@@ -180,6 +182,15 @@
 
         _init_anchor(p, state, group)
 
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+        self.__init_state(p, group)
+
         step_size = group['lr']
 
         random_int_tensor = None
@@ -193,7 +204,7 @@
         if group['actual_state_precision'] == 'bf16_sr' and random_int_state_tensor is None:
             random_int_state_tensor = param_update._get_random_int_for_sr(p)
         elif group['actual_state_precision'] == 'int8_sr':
-            random_int_state_tensor = param_update._get_random_int_for_int8_sr(p)
+            random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
         elif group['actual_state_precision'] == 'fp8_sr':
             random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
         step_param_fn = self._compiled_step_parameter
@@ -219,7 +230,7 @@
 
         if momentum != 0:
             buf = _reconstruct_state((state['mu_b_nmf'], state['mv_b_nmf'], state['sign'], d2), signed=True)
-            buf.mul_(momentum).add_(grad_reshaped, alpha=1 - momentum)
+            buf.lerp_(grad_reshaped, 1 - momentum)
 
             # Factorize updated buffer
             state['mu_b_nmf'], state['mv_b_nmf'], state['sign'] = _factorize_state(buf.clone(), signed=True)
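This hunk and the next one replace the explicit mul_/add_ exponential-moving-average update with lerp_. The two forms are mathematically identical, since lerp_(end, w) computes self + w * (end - self); a quick standalone check (illustration only, not package code):

    import torch

    momentum = 0.9
    buf, grad = torch.randn(4, 4), torch.randn(4, 4)

    old = buf.clone().mul_(momentum).add_(grad, alpha=1 - momentum)
    new = buf.clone().lerp_(grad, 1 - momentum)   # buf + (1 - momentum) * (grad - buf)

    assert torch.allclose(old, new, atol=1e-6)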
@@ -239,9 +250,7 @@
 
         if momentum != 0:
             buf = get_state(state, 'momentum_buffer', actual_precision)
-
-            buf.mul_(momentum).add_(grad, alpha=1 - momentum)
-
+            buf.lerp_(grad, 1 - momentum)
 
             set_state(state, 'momentum_buffer', buf, actual_precision, random_int_state_tensor)
 
@@ -254,8 +263,8 @@
 
         del random_int_state_tensor
 
-        if group['sinkhorn']:
-            update = apply_sr_sinkhorn(update, iters=group['sinkhorn_iterations'])
+        # Sinkhorn iterative normalization
+        update = apply_sr_sinkhorn(update, p, ortho_project=group['orthogonal_sinkhorn'], iters=group['sinkhorn_iterations'])
 
         update_scaling = step_size
         if group.get('spectral_normalization', False):
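With this change the old sinkhorn on/off flag is gone and the normalization is always applied; only the iteration count and the new orthogonal_sinkhorn option remain configurable. A minimal usage sketch of the renamed optimizer, assuming the 2.4.dev13 release and the standard step()/zero_grad() loop; the lr argument name is inferred from its use as group['lr'] in step_parameter, so treat it as an assumption:

    import torch
    from adv_optm import SinkSGD_adv   # exported by adv_optm/__init__.py as of 2.4.dev13

    model = torch.nn.Linear(32, 8)
    opt = SinkSGD_adv(
        model.parameters(),
        lr=1e-2,                     # assumed argument name
        sinkhorn_iterations=5,       # number of alternating normalization passes
        orthogonal_sinkhorn=True,    # project the update orthogonally to p during Sinkhorn
    )

    loss = model(torch.randn(4, 32)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()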
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm/optim/__init__.py
@@ -7,7 +7,7 @@ from .Lion_Prodigy_adv import Lion_Prodigy_adv
 from .Muon_adv import Muon_adv
 from .AdaMuon_adv import AdaMuon_adv
 from .SignSGD_adv import SignSGD_adv
-from .SGD_adv import SGD_adv
+from .SinkSGD_adv import SinkSGD_adv
 
 __all__ = [
     "AdamW_adv",
@@ -19,5 +19,5 @@ __all__ = [
     "Muon_adv",
     "AdaMuon_adv",
     "SignSGD_adv",
-    "SGD_adv",
+    "SinkSGD_adv",
 ]
adv_optm-2.4.dev13/adv_optm/util/sinkhorn.py
@@ -0,0 +1,77 @@
+import math
+import torch
+
+def apply_sr_sinkhorn(update: torch.Tensor, p: torch.Tensor, ortho_project: bool, iters: int = 5) -> torch.Tensor:
+    """
+    Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
+    As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
+
+    This technique normalizes a 2D matrix alternatively by its row-wise L2 norm
+    and column-wise L2 norm, driving it toward a fixed point that uniformly
+    distributes update magnitudes.
+    """
+    original_shape = update.shape
+    original_dtype = update.dtype
+    update = update.float()
+
+    # 1D Vector Case
+    if update.dim() == 1:
+        if ortho_project:
+            p_float = p.float()
+            p_norm_sq = torch.dot(p_float, p_float).add_(1e-30)
+            proj = torch.dot(p_float, update) / p_norm_sq
+            update.sub_(p_float * proj)
+        norm = update.norm(p=2).clamp_min_(1e-12)
+        return update.mul_(math.sqrt(update.numel()) / norm).view(original_shape).to(original_dtype)
+
+    # 2D+ Matrix Case
+    update_2d = update.view(update.shape[0], -1)
+
+    m, n = update_2d.shape
+
+    # Dynamically determine the order of normalization based on aspect ratio
+    # Normalizing the longer dimension first aids stability.
+    scale_cond = update_2d.shape[0] > update_2d.shape[1]
+    dim = 0 if scale_cond else 1
+
+
+    # Precompute scaling factors.
+    scale_first = math.sqrt(m if scale_cond else n)
+    scale_second = math.sqrt(n if scale_cond else m)
+
+    if ortho_project:
+        param_2d = p.float().view(p.shape[0], -1)
+        p_norm_sq_dim = torch.sum(param_2d * param_2d, dim=dim, keepdim=True).add_(1e-30)
+        p_norm_sq_adim = torch.sum(param_2d * param_2d, dim=1-dim, keepdim=True).add_(1e-30)
+
+    # In-place alternating Sinkhorn normalization steps
+    for _ in range(iters):
+        # First normalization step
+        norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+        update_2d.mul_(scale_first / norm1)
+        if ortho_project:
+            update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_dim, dim, scale_first)
+
+        # Second normalization step
+        norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
+        update_2d.mul_(scale_second / norm2)
+        if ortho_project:
+            update_2d = ortho_normed(param_2d, update_2d, p_norm_sq_adim, 1-dim, scale_second)
+
+    return update_2d.view(original_shape).to(original_dtype)
+
+def ortho_normed(p_2d, update_2d, p_norm_sq, dim, target_norm):
+    """
+    Projects the update to be orthogonal to p along 'dim' and restores the original norm.
+    """
+    # Project: g_orth = g - (p * <p, g> / ||p||^2)
+    dot_prod = torch.sum(p_2d * update_2d, dim=dim, keepdim=True)
+    proj = dot_prod / p_norm_sq
+
+    # In-place subtraction: update_2d = update_2d - (proj * p_2d)
+    update_2d.addcmul_(proj, p_2d, value=-1.0)
+
+    # Magnitude Preservation
+    g_orth_norm = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
+    scale_factor = target_norm / g_orth_norm
+    return update_2d.mul_(scale_factor)
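For intuition about what the alternating normalization in the new file converges to, here is a standalone re-implementation of the core loop (illustration only, independent of the package): at the fixed point every column norm is roughly sqrt(m) and every row norm roughly sqrt(n), i.e. entries with unit RMS.

    import math
    import torch

    torch.manual_seed(0)
    g = torch.randn(6, 3) * 5.0          # m=6 > n=3, arbitrary scale
    m, n = g.shape

    for _ in range(5):
        # longer dimension first, as in apply_sr_sinkhorn when m > n
        g.mul_(math.sqrt(m) / g.norm(dim=0, keepdim=True).clamp_min(1e-12))  # column norms -> sqrt(m)
        g.mul_(math.sqrt(n) / g.norm(dim=1, keepdim=True).clamp_min(1e-12))  # row norms -> sqrt(n)

    print(g.norm(dim=1))   # very close to sqrt(3): the row step ran last
    print(g.norm(dim=0))   # approximately sqrt(6) once the iteration has converged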
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 2.4.dev11
+Version: 2.4.dev13
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/adv_optm.egg-info/SOURCES.txt
@@ -14,9 +14,9 @@ adv_optm/optim/Lion_Prodigy_adv.py
 adv_optm/optim/Lion_adv.py
 adv_optm/optim/Muon_adv.py
 adv_optm/optim/Prodigy_adv.py
-adv_optm/optim/SGD_adv.py
 adv_optm/optim/SignSGD_adv.py
 adv_optm/optim/Simplified_AdEMAMix.py
+adv_optm/optim/SinkSGD_adv.py
 adv_optm/optim/__init__.py
 adv_optm/util/Kourkoutas.py
 adv_optm/util/Muon_AuxAdam.py
{adv_optm-2.4.dev11 → adv_optm-2.4.dev13}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
 setup(
     name="adv_optm",
-    version="2.4.dev11",
+    version="2.4.dev13",
    author="Koratahiu",
     author_email="hiuhonor@gmail.com",
     license='Apache 2.0',
adv_optm-2.4.dev11/adv_optm/util/sinkhorn.py
@@ -1,42 +0,0 @@
-import math
-import torch
-
-def apply_sr_sinkhorn(update: torch.Tensor, iters: int = 5) -> torch.Tensor:
-    """
-    Applies Square-Root Sinkhorn (SR-Sinkhorn) multi-normalization.
-    As described in 'Gradient Multi-Normalization for Efficient LLM Training'.
-
-    This technique normalizes a 2D matrix alternatively by its row-wise L2 norm
-    and column-wise L2 norm, driving it toward a fixed point that uniformly
-    distributes update magnitudes.
-    """
-    original_shape = update.shape
-
-    if update.dim() == 1:
-        norm = update.norm(p=2).clamp_min_(1e-12)
-        return update.mul_(math.sqrt(update.numel()) / norm)
-    else:
-        # Flatten >= 3D tensors into 2D matrices
-        update_2d = update.view(update.shape[0], -1)
-
-        m, n = update_2d.shape
-
-        # Dynamically determine the order of normalization based on aspect ratio
-        # Normalizing the longer dimension first aids stability.
-        dim = 0 if m > n else 1
-
-        # Precompute scaling factors.
-        scale_first = math.sqrt(m) if dim == 0 else math.sqrt(n)
-        scale_second = math.sqrt(n) if dim == 0 else math.sqrt(m)
-
-        # In-place alternating Sinkhorn normalization steps
-        for _ in range(iters):
-            # First normalization step
-            norm1 = update_2d.norm(p=2, dim=dim, keepdim=True).clamp_min_(1e-12)
-            update_2d.mul_(scale_first / norm1)
-
-            # Second normalization step
-            norm2 = update_2d.norm(p=2, dim=1-dim, keepdim=True).clamp_min_(1e-12)
-            update_2d.mul_(scale_second / norm2)
-
-        return update_2d.view(original_shape)