hyper-connections 0.4.0__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17):
  1. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/PKG-INFO +1 -1
  2. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/manifold_constrained_hyper_connections.py +10 -7
  3. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/pyproject.toml +1 -1
  4. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/.github/workflows/python-publish.yml +0 -0
  5. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/.github/workflows/test.yml +0 -0
  6. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/.gitignore +0 -0
  7. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/LICENSE +0 -0
  8. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/README.md +0 -0
  9. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper-connections.png +0 -0
  10. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/__init__.py +0 -0
  11. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/hyper_connections.py +0 -0
  12. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/hyper_connections_channel_first.py +0 -0
  13. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  14. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  15. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/residuals.py +0 -0
  16. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/hyper_connections/vit.py +0 -0
  17. {hyper_connections-0.4.0 → hyper_connections-0.4.2}/tests/test_hyper_connections.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -213,8 +213,10 @@ class ManifoldConstrainedHyperConnections(Module):
213
213
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
214
214
  sinkhorn_iters = 20,
215
215
  log_domain_sinkhorn = False,
216
+ residual_mix_constraint_fn: Callable | None = None,
216
217
  forward_method_names: tuple[str, ...] = (),
217
- num_dynamic_alpha_proposals = 1
218
+ num_dynamic_alpha_proposals = 1,
219
+
218
220
  ):
219
221
  """
220
222
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -287,10 +289,13 @@ class ManifoldConstrainedHyperConnections(Module):
287
289
 
288
290
  self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)
289
291
 
290
- # sinkhorn related
292
+ # Hres constraint related
293
+ # by default is sinkhorn
291
294
 
292
- self.sinkhorn_iters = sinkhorn_iters
293
- self.log_domain_sinkhorn = log_domain_sinkhorn
295
+ self.residual_mix_constraint_fn = default(
296
+ residual_mix_constraint_fn,
297
+ partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
298
+ )
294
299
 
295
300
  # dropouts
296
301
 
@@ -373,9 +378,7 @@ class ManifoldConstrainedHyperConnections(Module):
373
378
 
374
379
  alpha_pre = alpha_pre.sigmoid()
375
380
 
376
- sinkhorn_fn = sinkhorn_knopps if not self.log_domain_sinkhorn else log_domain_sinkhorn_knopps
377
-
378
- alpha_residual = sinkhorn_fn(alpha_residual, self.sinkhorn_iters)
381
+ alpha_residual = self.residual_mix_constraint_fn(alpha_residual)
379
382
 
380
383
  alpha = cat((alpha_pre, alpha_residual), dim = -1)
381
384
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.4.0"
3
+ version = "0.4.2"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }