hyper-connections 0.4.6__tar.gz → 0.4.7__tar.gz

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (19)
  1. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/PKG-INFO +1 -1
  2. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/mHCv2.py +1 -1
  3. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/vit.py +9 -6
  4. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/pyproject.toml +1 -1
  5. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/.github/workflows/python-publish.yml +0 -0
  6. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/.github/workflows/test.yml +0 -0
  7. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/.gitignore +0 -0
  8. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/LICENSE +0 -0
  9. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/README.md +0 -0
  10. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper-connections.png +0 -0
  11. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/__init__.py +0 -0
  12. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/hyper_connections.py +0 -0
  13. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/hyper_connections_channel_first.py +0 -0
  14. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  15. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  16. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/manifold_constrained_hyper_connections.py +0 -0
  17. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/residuals.py +0 -0
  18. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/hyper_connections/triton_sinkhorn.py +0 -0
  19. {hyper_connections-0.4.6 → hyper_connections-0.4.7}/tests/test_hyper_connections.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.6
3
+ Version: 0.4.7
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -124,8 +124,8 @@ def get_init_and_expand_reduce_stream_functions(
124
124
 
125
125
  hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
126
126
 
127
- kwargs.pop('add_attn_pool_reduce_stream', None)
128
127
  init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
128
+
129
129
  expand_reduce_fns = get_expand_reduce_stream_functions(
130
130
  num_streams,
131
131
  add_stream_embed = add_stream_embed,
@@ -5,7 +5,7 @@ from torch.nn import Module, ModuleList
5
5
  from einops import rearrange, repeat
6
6
  from einops.layers.torch import Rearrange
7
7
 
8
- from hyper_connections.manifold_constrained_hyper_connections import mHC
8
+ from hyper_connections.mHCv2 import mHC
9
9
 
10
10
  # helpers
11
11
 
@@ -66,12 +66,12 @@ class Attention(Module):
66
66
  return self.to_out(out)
67
67
 
68
68
  class Transformer(Module):
69
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
69
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, mhc_kwargs = dict()):
70
70
  super().__init__()
71
71
  self.norm = nn.LayerNorm(dim)
72
72
  self.layers = ModuleList([])
73
73
 
74
- init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
74
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, **mhc_kwargs)
75
75
 
76
76
  for _ in range(depth):
77
77
  self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
92
92
  return self.norm(x)
93
93
 
94
94
  class ViT(Module):
95
- def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
95
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, mhc_kwargs = dict(num_dynamic_alpha_proposals = 1)):
96
96
  super().__init__()
97
97
  image_height, image_width = pair(image_size)
98
98
  patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
117
117
 
118
118
  self.dropout = nn.Dropout(emb_dropout)
119
119
 
120
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
120
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, mhc_kwargs)
121
121
 
122
122
  self.pool = pool
123
123
  self.to_latent = nn.Identity()
@@ -154,7 +154,10 @@ if __name__ == '__main__':
154
154
  mlp_dim = 2048,
155
155
  dropout = 0.1,
156
156
  emb_dropout = 0.1,
157
- num_residual_streams = 4
157
+ num_residual_streams = 4,
158
+ mhc_kwargs = dict(
159
+ use_triton_sinkhorn = False
160
+ )
158
161
  )
159
162
 
160
163
  img = torch.randn(1, 3, 256, 256)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.4.6"
3
+ version = "0.4.7"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }