hyper-connections 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,6 +47,10 @@ def l1norm(t, dim):
47
47
  return F.normalize(t, p = 1, dim = dim)
48
48
 
49
49
  def sinkhorn_knopps(log_alpha, iters = 20):
50
+
51
+ if iters <= 0:
52
+ return log_alpha
53
+
50
54
  assert log_alpha.shape[-2] == log_alpha.shape[-1]
51
55
 
52
56
  dtype = log_alpha.dtype
@@ -63,6 +67,10 @@ def sinkhorn_knopps(log_alpha, iters = 20):
63
67
  return alpha.to(dtype)
64
68
 
65
69
  def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
70
+
71
+ if iters <= 0:
72
+ return log_alpha
73
+
66
74
  assert log_alpha.shape[-2] == log_alpha.shape[-1]
67
75
 
68
76
  dtype = log_alpha.dtype
@@ -116,8 +124,8 @@ def get_init_and_expand_reduce_stream_functions(
116
124
 
117
125
  hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
118
126
 
119
- kwargs.pop('add_attn_pool_reduce_stream', None)
120
127
  init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
128
+
121
129
  expand_reduce_fns = get_expand_reduce_stream_functions(
122
130
  num_streams,
123
131
  add_stream_embed = add_stream_embed,
hyper_connections/vit.py CHANGED
@@ -5,7 +5,7 @@ from torch.nn import Module, ModuleList
5
5
  from einops import rearrange, repeat
6
6
  from einops.layers.torch import Rearrange
7
7
 
8
- from hyper_connections.manifold_constrained_hyper_connections import mHC
8
+ from hyper_connections.mHCv2 import mHC
9
9
 
10
10
  # helpers
11
11
 
@@ -66,12 +66,12 @@ class Attention(Module):
66
66
  return self.to_out(out)
67
67
 
68
68
  class Transformer(Module):
69
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
69
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., num_residual_streams = 4, mhc_kwargs = dict()):
70
70
  super().__init__()
71
71
  self.norm = nn.LayerNorm(dim)
72
72
  self.layers = ModuleList([])
73
73
 
74
- init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_dynamic_alpha_proposals = num_dynamic_alpha_proposals)
74
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mHC.get_init_and_expand_reduce_stream_functions(num_residual_streams, **mhc_kwargs)
75
75
 
76
76
  for _ in range(depth):
77
77
  self.layers.append(ModuleList([
@@ -92,7 +92,7 @@ class Transformer(Module):
92
92
  return self.norm(x)
93
93
 
94
94
  class ViT(Module):
95
- def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, num_dynamic_alpha_proposals = 1):
95
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., num_residual_streams = 4, mhc_kwargs = dict(num_dynamic_alpha_proposals = 1)):
96
96
  super().__init__()
97
97
  image_height, image_width = pair(image_size)
98
98
  patch_height, patch_width = pair(patch_size)
@@ -117,7 +117,7 @@ class ViT(Module):
117
117
 
118
118
  self.dropout = nn.Dropout(emb_dropout)
119
119
 
120
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, num_dynamic_alpha_proposals)
120
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_residual_streams, mhc_kwargs)
121
121
 
122
122
  self.pool = pool
123
123
  self.to_latent = nn.Identity()
@@ -154,7 +154,10 @@ if __name__ == '__main__':
154
154
  mlp_dim = 2048,
155
155
  dropout = 0.1,
156
156
  emb_dropout = 0.1,
157
- num_residual_streams = 4
157
+ num_residual_streams = 4,
158
+ mhc_kwargs = dict(
159
+ use_triton_sinkhorn = False
160
+ )
158
161
  )
159
162
 
160
163
  img = torch.randn(1, 3, 256, 256)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.5
3
+ Version: 0.4.7
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,12 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
6
- hyper_connections/mHCv2.py,sha256=wCtp87OFI3QfosdSL-1qwsiQN9f8gX32_0r8GQGO7P0,17411
6
+ hyper_connections/mHCv2.py,sha256=XB2HwxTo7daZvy9fzF8UjHI12ephwgE91h9AH2Ou4WI,17452
7
7
  hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
8
8
  hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
9
9
  hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
10
- hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
11
- hyper_connections-0.4.5.dist-info/METADATA,sha256=sWVb_-yVRmxL8AsAPsk0VdRXOa25uG9zKNc8S_oAXg8,6704
12
- hyper_connections-0.4.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
13
- hyper_connections-0.4.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
14
- hyper_connections-0.4.5.dist-info/RECORD,,
10
+ hyper_connections/vit.py,sha256=dh8AVMUPaUHuWxXJEHoMW_G5nj-EQQjDmgbPwwhiq5g,5215
11
+ hyper_connections-0.4.7.dist-info/METADATA,sha256=2ajn-IuCxuUjgnOw5dEBxXKqLJbyQohHSGgKJ2dZFoA,6704
12
+ hyper_connections-0.4.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
13
+ hyper_connections-0.4.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
14
+ hyper_connections-0.4.7.dist-info/RECORD,,