PVNet 5.3.0__py3-none-any.whl → 5.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/models/late_fusion/site_encoders/encoders.py +2 -2
- {pvnet-5.3.0.dist-info → pvnet-5.3.1.dist-info}/METADATA +1 -1
- {pvnet-5.3.0.dist-info → pvnet-5.3.1.dist-info}/RECORD +6 -6
- {pvnet-5.3.0.dist-info → pvnet-5.3.1.dist-info}/WHEEL +0 -0
- {pvnet-5.3.0.dist-info → pvnet-5.3.1.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.3.0.dist-info → pvnet-5.3.1.dist-info}/top_level.txt +0 -0
|
@@ -158,7 +158,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
158
158
|
super().__init__(sequence_length, num_sites, out_features)
|
|
159
159
|
self.sequence_length = sequence_length
|
|
160
160
|
self.target_id_embedding = nn.Embedding(target_id_dim, out_features)
|
|
161
|
-
self.
|
|
161
|
+
self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim)
|
|
162
162
|
self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
|
|
163
163
|
self.use_id_in_value = use_id_in_value
|
|
164
164
|
self.key_to_use = key_to_use
|
|
@@ -224,7 +224,7 @@ class SingleAttentionNetwork(AbstractSitesEncoder):
|
|
|
224
224
|
site_seqs, batch_size = self._encode_inputs(x)
|
|
225
225
|
|
|
226
226
|
# site ID embeddings are the same for each sample
|
|
227
|
-
id_embed = torch.tile(self.
|
|
227
|
+
id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
|
|
228
228
|
# Each concated (site sequence, site ID embedding) is processed with encoder
|
|
229
229
|
x_seq_in = torch.cat((site_seqs, id_embed), dim=2).flatten(0, 1)
|
|
230
230
|
key = self._key_encoder(x_seq_in)
|
|
@@ -17,13 +17,13 @@ pvnet/models/late_fusion/linear_networks/basic_blocks.py,sha256=RnwdeuX_-itY4ncM
|
|
|
17
17
|
pvnet/models/late_fusion/linear_networks/networks.py,sha256=exEIz_Z85f8nSwcvp4wqiiLECEAg9YbkKhSZJvFy75M,2231
|
|
18
18
|
pvnet/models/late_fusion/site_encoders/__init__.py,sha256=QoUiiWWFf12vEpdkw0gO4TWpOEoI_tgAyUFCWFFpYAk,45
|
|
19
19
|
pvnet/models/late_fusion/site_encoders/basic_blocks.py,sha256=iEB_N7ZL5HMQ1hZM6H32A71GCwP7YbErUx0oQF21PQM,1042
|
|
20
|
-
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=
|
|
20
|
+
pvnet/models/late_fusion/site_encoders/encoders.py,sha256=PemEUa_Wv5pFWw3usPKEtXcvs_MX2LSrO6nhldO_QVk,11320
|
|
21
21
|
pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
22
22
|
pvnet/training/lightning_module.py,sha256=57sT7bPCU7mJw4EskzOE-JJ9JhWIuAbs40_x5RoBbA8,12705
|
|
23
23
|
pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
|
|
24
24
|
pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
|
|
25
|
-
pvnet-5.3.
|
|
26
|
-
pvnet-5.3.
|
|
27
|
-
pvnet-5.3.
|
|
28
|
-
pvnet-5.3.
|
|
29
|
-
pvnet-5.3.
|
|
25
|
+
pvnet-5.3.1.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
26
|
+
pvnet-5.3.1.dist-info/METADATA,sha256=LMfxIQEjnBwoJQktBq3DOEKYgcUUxaMD6k3s6vOBWiU,16479
|
|
27
|
+
pvnet-5.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
28
|
+
pvnet-5.3.1.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
29
|
+
pvnet-5.3.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|