pyg-nightly 2.7.0.dev20250221__py3-none-any.whl → 2.7.0.dev20250223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/RECORD +9 -9
- torch_geometric/__init__.py +1 -1
- torch_geometric/hash_tensor.py +186 -8
- torch_geometric/metrics/link_pred.py +26 -5
- torch_geometric/testing/__init__.py +2 -0
- torch_geometric/testing/decorators.py +12 -0
- {pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.
|
3
|
+
Version: 2.7.0.dev20250223
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=u79QBiX3vYzq10QGHDb7To3rzMR1UXXopOn8ptXDZ0A,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -9,7 +9,7 @@ torch_geometric/deprecation.py,sha256=dWRymDIUkUVI2MeEmBG5WF4R6jObZeseSBV9G6FNfj
|
|
9
9
|
torch_geometric/device.py,sha256=tU5-_lBNVbVHl_kUmWPwiG5mQ1pyapwMF4JkmtNN3MM,1224
|
10
10
|
torch_geometric/edge_index.py,sha256=BsLh5tOZRjjSYDkjqOFAdBuvMaDg7EWaaLELYsUL0Z8,70048
|
11
11
|
torch_geometric/experimental.py,sha256=JbtNNEXjFGI8hZ9raM6-qrZURP6Z5nlDK8QicZUIbz0,4756
|
12
|
-
torch_geometric/hash_tensor.py,sha256=
|
12
|
+
torch_geometric/hash_tensor.py,sha256=9Zg1KCebfN-xJE1dX2nGGYnK09snSyJkjaYVzCUOfkM,17278
|
13
13
|
torch_geometric/home.py,sha256=EV54B4Dmiv61GDbkCwtCfWGWJ4eFGwZ8s3KOgGjwYgY,790
|
14
14
|
torch_geometric/index.py,sha256=9ChzWFCwj2slNcVBOgfV-wQn-KscJe_y7502w-Vf76w,24045
|
15
15
|
torch_geometric/inspector.py,sha256=nKi5o4Mn6xsG0Ex1GudTEQt_EqnF9mcMqGtp7Shh9sQ,19336
|
@@ -290,7 +290,7 @@ torch_geometric/loader/temporal_dataloader.py,sha256=AQ2QFeiXKbPp6I8sUeE8H7br-1_
|
|
290
290
|
torch_geometric/loader/utils.py,sha256=f27mczQ7fEP2HpTsJGJxKS0slPu0j8zTba3jP8ViNck,14901
|
291
291
|
torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
|
292
292
|
torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
|
293
|
-
torch_geometric/metrics/link_pred.py,sha256=
|
293
|
+
torch_geometric/metrics/link_pred.py,sha256=wGQG-Fl6BQYJMLZe_L_iIl4ixj6TWgLkkuHyMMraWBA,30480
|
294
294
|
torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
|
295
295
|
torch_geometric/nn/data_parallel.py,sha256=lDAxRi83UNuzAQSj3eu9K2sQheOIU6wqR5elS6oDs90,4764
|
296
296
|
torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
|
@@ -513,10 +513,10 @@ torch_geometric/sampler/base.py,sha256=kT6hYM6losYta3pqLQlqiqboJiujLy6RlH8qM--U_
|
|
513
513
|
torch_geometric/sampler/hgt_sampler.py,sha256=UAm8_wwzEcziKDJ8-TnfZh1705dXRsy_I5PKhZSDTK8,2721
|
514
514
|
torch_geometric/sampler/neighbor_sampler.py,sha256=MAVphWqNf0-cwlHRvdiU8de86dBxwjm3Miam_6s1ep4,33971
|
515
515
|
torch_geometric/sampler/utils.py,sha256=RJtasO6Q7Pp3oYEOWrbf2DEYuSfuKZOsF2I7-eJDnoA,5485
|
516
|
-
torch_geometric/testing/__init__.py,sha256=
|
516
|
+
torch_geometric/testing/__init__.py,sha256=0mAGVWRrTBNsGV2YUkCu_FkyQ8JIcrYVw2LsdKgY9ak,1291
|
517
517
|
torch_geometric/testing/asserts.py,sha256=DLC9HnBgFWuTIiQs2OalsQcXGhOVG-e6R99IWhkO32c,4606
|
518
518
|
torch_geometric/testing/data.py,sha256=O1qo8FyNxt6RGf63Ys3eXBfa5RvYydeZLk74szrez3c,2604
|
519
|
-
torch_geometric/testing/decorators.py,sha256=
|
519
|
+
torch_geometric/testing/decorators.py,sha256=j45wlxMB1-Pn3wPKBgDziqg6KkWJUb_fcwfUXzkL2mM,8677
|
520
520
|
torch_geometric/testing/distributed.py,sha256=ZZCCXqiQC4-m1ExSjDZhS_a1qPXnHEwhJGTmACxNnVI,2227
|
521
521
|
torch_geometric/testing/feature_store.py,sha256=J6JBIt2XK-t8yG8B4JzXp-aJcVl5jaCS1m2H7d6OUxs,2158
|
522
522
|
torch_geometric/testing/graph_store.py,sha256=00B7QToCIspYmgN7svQKp1iU-qAzEtrt3VQRFxkHfuk,1044
|
@@ -633,7 +633,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
633
633
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
634
634
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
635
635
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
636
|
-
pyg_nightly-2.7.0.
|
637
|
-
pyg_nightly-2.7.0.
|
638
|
-
pyg_nightly-2.7.0.
|
639
|
-
pyg_nightly-2.7.0.
|
636
|
+
pyg_nightly-2.7.0.dev20250223.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
637
|
+
pyg_nightly-2.7.0.dev20250223.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82
|
638
|
+
pyg_nightly-2.7.0.dev20250223.dist-info/METADATA,sha256=o3vW1MbKajweST33mDeCk-b1CKb5wGegFomLfUE_rOQ,63021
|
639
|
+
pyg_nightly-2.7.0.dev20250223.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250223'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
torch_geometric/hash_tensor.py
CHANGED
@@ -142,12 +142,12 @@ class HashTensor(Tensor):
|
|
142
142
|
if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000
|
143
143
|
or _range <= 2 * key.numel()):
|
144
144
|
_map = torch.full(
|
145
|
-
size=(_range +
|
145
|
+
size=(_range + 3, ),
|
146
146
|
fill_value=-1,
|
147
147
|
dtype=torch.int64,
|
148
|
-
device=device,
|
148
|
+
device=key.device,
|
149
149
|
)
|
150
|
-
_map[(
|
150
|
+
_map[key.long() - (min_key.long() - 1)] = torch.arange(
|
151
151
|
key.numel(),
|
152
152
|
dtype=_map.dtype,
|
153
153
|
device=_map.device,
|
@@ -164,6 +164,8 @@ class HashTensor(Tensor):
|
|
164
164
|
dtype=dtype,
|
165
165
|
)
|
166
166
|
|
167
|
+
# Private Methods #########################################################
|
168
|
+
|
167
169
|
@classmethod
|
168
170
|
def _from_data(
|
169
171
|
cls,
|
@@ -207,6 +209,20 @@ class HashTensor(Tensor):
|
|
207
209
|
|
208
210
|
return out
|
209
211
|
|
212
|
+
@property
def _key(self) -> Tensor:
    r"""Reconstructs the key tensor in its original (insertion) order.

    The returned tensor satisfies ``self._key[i] == key_i``, where ``i`` is
    the row index that was assigned to ``key_i`` when the lookup structure
    in :obj:`self._map` was built.
    """
    if isinstance(self._map, Tensor):
        # Dense lookup table: slot ``p`` holds the row index of key
        # ``p + (min_key - 1)`` (slot 0 and the last slot are `-1` sentinels
        # used for out-of-range queries), or `-1` if that key is absent.
        mask = self._map >= 0
        slot = mask.nonzero().view(-1)
        # Undo the `key - (min_key - 1)` shift applied on construction:
        key = slot + (self._min_key.long() - 1)
        # `key` is now in ascending key order, and `self._map[mask]` holds the
        # row index of each of these keys. Sorting by row index restores the
        # original order (gathering by it would apply the inverse
        # permutation instead).
        key = key[self._map[mask].argsort()]
    elif (torch_geometric.typing.WITH_CUDA_HASH_MAP
          or torch_geometric.typing.WITH_CPU_HASH_MAP):
        key = self._map.keys().to(self.device)
    else:
        key = torch.from_numpy(self._map.categories.to_numpy())

    return key.to(self.device)
|
225
|
+
|
210
226
|
def _shallow_copy(self) -> 'HashTensor':
|
211
227
|
return self._from_data(
|
212
228
|
self._map,
|
@@ -217,6 +233,37 @@ class HashTensor(Tensor):
|
|
217
233
|
dtype=self.dtype,
|
218
234
|
)
|
219
235
|
|
236
|
+
def _get(self, query: Tensor) -> Tensor:
    r"""Looks up the keys in :obj:`query` and returns their entries.

    Keys are first resolved to row indices through :obj:`self._map`, which
    is either a dense lookup tensor, a CUDA/CPU hash map, or a
    :class:`pandas` categorical dtype. Keys that are not present are
    expected to resolve to ``-1``.

    If no values are attached, the row indices are returned (cast to
    :obj:`self.dtype`). Otherwise, the matching value rows are returned,
    where rows of missing keys are filled with ``NaN`` (floating point
    values) or ``-1`` (all other dtypes).
    """
    hash_map = self._map

    if isinstance(hash_map, Tensor):
        # Shift queries so that `min_key` maps onto slot 1; anything outside
        # the table is clamped onto one of the `-1` boundary slots.
        offset = self._min_key.long() - 1
        slot = query.long() - offset
        slot = slot.clamp_(min=0, max=hash_map.numel() - 1)
        index = hash_map[slot]
    elif torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda:
        index = hash_map.get(query)
    elif torch_geometric.typing.WITH_CPU_HASH_MAP:
        index = hash_map.get(query.cpu())
    else:
        import pandas as pd

        codes = pd.Series(query.cpu().numpy(), dtype=hash_map).cat.codes
        index = torch.from_numpy(codes.to_numpy()).to(torch.long)

    index = index.to(self.device)

    if self._value is None:
        return index.to(self.dtype)

    out = self._value[index]

    # Broadcast the "key found" mask over all trailing value dimensions:
    found = index != -1
    found = found.view([-1] + [1] * (out.dim() - 1))

    fill_value = float('NaN') if out.is_floating_point() else -1
    other: Union[int, float, Tensor]
    if torch_geometric.typing.WITH_PT20:
        other = fill_value
    else:
        other = torch.full_like(out, fill_value)

    return out.where(found, other)
|
266
|
+
|
220
267
|
# Methods #################################################################
|
221
268
|
|
222
269
|
def as_tensor(self) -> Tensor:
|
@@ -252,6 +299,23 @@ class HashTensor(Tensor):
|
|
252
299
|
kwargs)
|
253
300
|
return func(*args, **(kwargs or {}))
|
254
301
|
|
302
|
+
def tolist(self) -> List[Any]:
    r"""Converts the output of :meth:`as_tensor` into a (nested) Python
    list."""
    dense = self.as_tensor()
    return dense.tolist()
|
304
|
+
|
305
|
+
def index_select(  # type: ignore
    self,
    dim: int,
    index: Any,
) -> Union['HashTensor', Tensor]:
    r"""Method form of :meth:`torch.index_select`.

    :meth:`torch.index_select` is looked up at call time, so the
    module-level wrapper installed in this file (which converts key indices
    along the first dimension) is honored.
    """
    selected = torch.index_select(self, dim, index)
    return selected
|
311
|
+
|
312
|
+
def select(  # type: ignore
    self,
    dim: int,
    index: Any,
) -> Union['HashTensor', Tensor]:
    r"""Method form of :meth:`torch.select`.

    :meth:`torch.select` is looked up at call time, so the module-level
    wrapper installed in this file (which converts a key index along the
    first dimension) is honored.
    """
    picked = torch.select(self, dim, index)
    return picked
|
318
|
+
|
255
319
|
|
256
320
|
@implements(aten.alias.default)
|
257
321
|
def _alias(tensor: HashTensor) -> HashTensor:
|
@@ -327,9 +391,13 @@ def _squeeze_default(tensor: HashTensor) -> HashTensor:
|
|
327
391
|
if tensor._value is None:
|
328
392
|
return tensor._shallow_copy()
|
329
393
|
|
394
|
+
value = tensor.as_tensor()
|
395
|
+
for d in range(tensor.dim() - 1, 0, -1):
|
396
|
+
value = value.squeeze(d)
|
397
|
+
|
330
398
|
return tensor._from_data(
|
331
399
|
tensor._map,
|
332
|
-
|
400
|
+
value,
|
333
401
|
tensor._min_key,
|
334
402
|
tensor._max_key,
|
335
403
|
num_keys=tensor.size(0),
|
@@ -355,11 +423,14 @@ def _squeeze_dim(
|
|
355
423
|
if tensor._value is None:
|
356
424
|
return tensor._shallow_copy()
|
357
425
|
|
358
|
-
|
426
|
+
value = tensor.as_tensor()
|
427
|
+
for d in dim[::-1]:
|
428
|
+
if d != 0 and d != -tensor.dim():
|
429
|
+
value = value.squeeze(d)
|
359
430
|
|
360
431
|
return tensor._from_data(
|
361
432
|
tensor._map,
|
362
|
-
|
433
|
+
value,
|
363
434
|
tensor._min_key,
|
364
435
|
tensor._max_key,
|
365
436
|
num_keys=tensor.size(0),
|
@@ -374,10 +445,18 @@ def _slice(
|
|
374
445
|
start: Optional[int] = None,
|
375
446
|
end: Optional[int] = None,
|
376
447
|
step: int = 1,
|
377
|
-
) ->
|
448
|
+
) -> HashTensor:
|
378
449
|
|
379
450
|
if dim == 0 or dim == -tensor.dim():
|
380
|
-
|
451
|
+
copy = start is None or (start == 0 or start <= -tensor.size(0))
|
452
|
+
copy &= end is None or end > tensor.size(0)
|
453
|
+
copy &= step == 1
|
454
|
+
if copy:
|
455
|
+
return tensor._shallow_copy()
|
456
|
+
|
457
|
+
key = aten.slice.Tensor(tensor._key, 0, start, end, step)
|
458
|
+
value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step)
|
459
|
+
return tensor.__class__(key, value)
|
381
460
|
|
382
461
|
return tensor._from_data(
|
383
462
|
tensor._map,
|
@@ -387,3 +466,102 @@ def _slice(
|
|
387
466
|
num_keys=tensor.size(0),
|
388
467
|
dtype=tensor.dtype,
|
389
468
|
)
|
469
|
+
|
470
|
+
|
471
|
+
# Since PyTorch does only allow PyTorch tensors as indices in `index_select`,
# we need to create a wrapper function and monkey patch `index_select` :(
_old_index_select = torch.index_select


def _new_index_select(
    input: Tensor,
    dim: int,
    index: Tensor,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Drop-in replacement for :meth:`torch.index_select` that additionally
    accepts arbitrary key containers as :obj:`index` when selecting along the
    first dimension of a :class:`HashTensor`.

    Any input that is not a :class:`HashTensor` is forwarded to the original
    :meth:`torch.index_select` unchanged. Because this function replaces
    :meth:`torch.index_select` globally, regular tensors must keep PyTorch's
    own argument handling (e.g., named dimensions or 0-dimensional inputs)
    and error reporting.

    Raises:
        IndexError: If :obj:`dim` is out of range for a :class:`HashTensor`
            input.
    """
    if not isinstance(input, HashTensor):
        return _old_index_select(input, dim, index, out=out)

    if dim < -input.dim() or dim >= input.dim():
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index tensor in the first dimension into a tensor. This
    # means that downstream handling (i.e. in `aten.index_select.default`)
    # needs to take this pre-conversion into account.
    if dim == 0 or dim == -input.dim():
        index = as_key_tensor(index, device=input.device)
    return _old_index_select(input, dim, index, out=out)


torch.index_select = _new_index_select  # type: ignore
|
497
|
+
|
498
|
+
|
499
|
+
@implements(aten.index_select.default)
def _index_select(
    tensor: HashTensor,
    dim: int,
    index: Tensor,
) -> Union[HashTensor, Tensor]:
    r"""Implements :obj:`aten.index_select` for :class:`HashTensor`.

    Selecting along the first (key) dimension looks up :obj:`index` as keys
    via :meth:`HashTensor._get`. Selecting along any other dimension keeps
    the key mapping and rebuilds the tensor via :meth:`_from_data` on the
    selected values.
    """
    along_keys = dim == 0 or dim == -tensor.dim()
    if along_keys:
        # When dispatched through the patched `torch.index_select`, `index`
        # has already been converted into a key tensor.
        return tensor._get(index)

    selected = aten.index_select.default(tensor.as_tensor(), dim, index)
    return tensor._from_data(
        tensor._map,
        selected,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
517
|
+
|
518
|
+
|
519
|
+
# Since PyTorch does only allow PyTorch tensors as indices in `select`, we need
# to create a wrapper function and monkey patch `select` :(
_old_select = torch.select


def _new_select(
    input: Tensor,
    dim: int,
    index: int,
) -> Tensor:
    r"""Drop-in replacement for :meth:`torch.select` that additionally
    accepts a raw key as :obj:`index` when selecting along the first
    dimension of a :class:`HashTensor`.

    Any input that is not a :class:`HashTensor` is forwarded to the original
    :meth:`torch.select` unchanged. Because this function replaces
    :meth:`torch.select` globally, regular tensors must keep PyTorch's own
    argument handling (e.g., named dimensions) and error reporting.

    Raises:
        IndexError: If :obj:`dim` is out of range for a :class:`HashTensor`
            input.
    """
    if not isinstance(input, HashTensor):
        return _old_select(input, dim, index)

    if dim < -input.dim() or dim >= input.dim():
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index in the first dimension into an integer. This means
    # that downstream handling (i.e. in `aten.select.int`) needs to take this
    # pre-conversion into account.
    if dim == 0 or dim == -input.dim():
        index = int(as_key_tensor([index]))
    return _old_select(input, dim, index)


torch.select = _new_select  # type: ignore
|
543
|
+
|
544
|
+
|
545
|
+
@implements(aten.select.int)
def _select(
    tensor: HashTensor,
    dim: int,
    index: int,
) -> Union[HashTensor, Tensor]:
    r"""Implements :obj:`aten.select.int` for :class:`HashTensor`.

    Selecting along the first (key) dimension treats :obj:`index` as a single
    key, looks it up via :meth:`HashTensor._get`, and drops the leading
    dimension of the result. Selecting along any other dimension keeps the
    key mapping and rebuilds the tensor via :meth:`_from_data` on the
    selected values.
    """
    if dim != 0 and dim != -tensor.dim():
        picked = aten.select.int(tensor.as_tensor(), dim, index)
        return tensor._from_data(
            tensor._map,
            picked,
            tensor._min_key,
            tensor._max_key,
            num_keys=tensor.size(0),
            dtype=tensor.dtype,
        )

    # Wrap the single key into a one-element query that matches the key
    # dtype/device, then remove the extra leading dimension again:
    query = torch.tensor(
        [index],
        dtype=tensor._min_key.dtype,
        device=tensor._min_key.device,
    )
    return tensor._get(query).squeeze(0)
|
@@ -715,22 +715,32 @@ class LinkPredPersonalization(_LinkPredMetric):
|
|
715
715
|
|
716
716
|
Args:
|
717
717
|
k (int): The number of top-:math:`k` predictions to evaluate against.
|
718
|
+
max_src_nodes (int, optional): The maximum source nodes to consider to
|
719
|
+
compute pair-wise dissimilarity. If specified,
|
720
|
+
Personalization @ :math:`k` is approximated to avoid computation
|
721
|
+
blowup due to quadratic complexity. (default: :obj:`2**12`)
|
718
722
|
batch_size (int, optional): The batch size to determine how many pairs
|
719
723
|
of user recommendations should be processed at once.
|
720
724
|
(default: :obj:`2**16`)
|
721
725
|
"""
|
722
726
|
higher_is_better: bool = True
|
723
727
|
|
724
|
-
def __init__(
    self,
    k: int,
    max_src_nodes: Optional[int] = 2**12,
    batch_size: int = 2**16,
) -> None:
    r"""Sets up the metric configuration and its accumulation state.

    :obj:`max_src_nodes` caps how many prediction rows (source nodes) are
    retained across :meth:`update` calls (:obj:`None` disables the cap).
    The ``total`` state counts how many rows have been retained so far.
    """
    super().__init__(k)
    self.max_src_nodes = max_src_nodes
    self.batch_size = batch_size

    if not WITH_TORCHMETRICS:
        self.preds: List[Tensor] = []
        self.register_buffer('total', torch.tensor(0))
    else:
        # Under `torchmetrics`, states are synchronized across processes:
        # predictions are concatenated and counters are summed.
        self.add_state('preds', default=[], dist_reduce_fx='cat')
        self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
|
734
744
|
|
735
745
|
def update(
|
736
746
|
self,
|
@@ -738,11 +748,21 @@ class LinkPredPersonalization(_LinkPredMetric):
|
|
738
748
|
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
|
739
749
|
edge_label_weight: Optional[Tensor] = None,
|
740
750
|
) -> None:
|
751
|
+
|
741
752
|
# NOTE Move to CPU to avoid memory blowup.
|
742
|
-
|
753
|
+
pred_index_mat = pred_index_mat[:, :self.k].cpu()
|
754
|
+
|
755
|
+
if self.max_src_nodes is None:
|
756
|
+
self.preds.append(pred_index_mat)
|
757
|
+
self.total += pred_index_mat.size(0)
|
758
|
+
elif self.total < self.max_src_nodes:
|
759
|
+
remaining = int(self.max_src_nodes - self.total)
|
760
|
+
pred_index_mat = pred_index_mat[:remaining]
|
761
|
+
self.preds.append(pred_index_mat)
|
762
|
+
self.total += pred_index_mat.size(0)
|
743
763
|
|
744
764
|
def compute(self) -> Tensor:
|
745
|
-
device = self.
|
765
|
+
device = self.total.device
|
746
766
|
score = torch.tensor(0.0, device=device)
|
747
767
|
total = torch.tensor(0, device=device)
|
748
768
|
|
@@ -786,6 +806,7 @@ class LinkPredPersonalization(_LinkPredMetric):
|
|
786
806
|
|
787
807
|
def _reset(self) -> None:
    r"""Drops all buffered prediction matrices and zeroes the ``total``
    counter in place."""
    self.preds = list()
    self.total.zero_()
|
789
810
|
|
790
811
|
|
791
812
|
class LinkPredAveragePopularity(_LinkPredMetric):
|
@@ -22,6 +22,7 @@ from .decorators import (
|
|
22
22
|
withDevice,
|
23
23
|
withCUDA,
|
24
24
|
withMETIS,
|
25
|
+
withHashTensor,
|
25
26
|
disableExtensions,
|
26
27
|
withoutExtensions,
|
27
28
|
)
|
@@ -53,6 +54,7 @@ __all__ = [
|
|
53
54
|
'withDevice',
|
54
55
|
'withCUDA',
|
55
56
|
'withMETIS',
|
57
|
+
'withHashTensor',
|
56
58
|
'disableExtensions',
|
57
59
|
'withoutExtensions',
|
58
60
|
'assert_module',
|
@@ -10,6 +10,7 @@ from packaging.requirements import Requirement
|
|
10
10
|
from packaging.version import Version
|
11
11
|
|
12
12
|
import torch_geometric
|
13
|
+
import torch_geometric.typing
|
13
14
|
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
|
14
15
|
from torch_geometric.visualization.graph import has_graphviz
|
15
16
|
|
@@ -265,6 +266,17 @@ def withMETIS(func: Callable) -> Callable:
|
|
265
266
|
)(func)
|
266
267
|
|
267
268
|
|
269
|
+
def withHashTensor(func: Callable) -> Callable:
    r"""A decorator to only test in case :class:`HashTensor` is available.

    The test is skipped unless either the CPU hash map extension or
    :obj:`pandas` can be used.
    """
    import pytest

    has_backend = (torch_geometric.typing.WITH_CPU_HASH_MAP
                   or has_package('pandas'))
    marker = pytest.mark.skipif(
        not has_backend,
        reason="HashTensor dependencies not available",
    )
    return marker(func)
|
278
|
+
|
279
|
+
|
268
280
|
def disableExtensions(func: Callable) -> Callable:
|
269
281
|
r"""A decorator to temporarily disable the usage of the
|
270
282
|
:obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
|
File without changes
|
{pyg_nightly-2.7.0.dev20250221.dist-info → pyg_nightly-2.7.0.dev20250223.dist-info}/licenses/LICENSE
RENAMED
File without changes
|