pyg-nightly 2.7.0.dev20250918__py3-none-any.whl → 2.7.0.dev20250920__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/RECORD +8 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/datasets/tag_dataset.py +134 -30
- torch_geometric/loader/utils.py +8 -8
- torch_geometric/metrics/link_pred.py +5 -4
- {pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250918
+Version: 2.7.0.dev20250920
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=0ij8VxVSK4T5A19Dr05CrBi0vbfTv2d0vTpB73hQsws,2292
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -151,7 +151,7 @@ torch_geometric/datasets/shapenet.py,sha256=tn3HiQQAr6lxHrqxfOVaAtl40guwFYTXWCbS
 torch_geometric/datasets/shrec2016.py,sha256=cTLhctbqE0EUEvKddJFhPzDb1oLKXOth4O_WzsWtyMk,6323
 torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7ZSUci645X0U,9481
 torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
-torch_geometric/datasets/tag_dataset.py,sha256=
+torch_geometric/datasets/tag_dataset.py,sha256=jslijGCh37ip2YkrQLyvbk-1QRJ3yqFpmzuQSxckXrE,19402
 torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
 torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
 torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
@@ -308,10 +308,10 @@ torch_geometric/loader/prefetch.py,sha256=z30TIcu3_6ZubllUOwNLunlq4RyQdFj36vPE5Q
 torch_geometric/loader/random_node_loader.py,sha256=rCmRXYv70SPxBo-Oh049eFEWEZDV7FmlRPzmjcoirXQ,2196
 torch_geometric/loader/shadow.py,sha256=_hCspYf9SlJYX0lqEjxFec9e9t1iMScNThOoWR1wQGM,4173
 torch_geometric/loader/temporal_dataloader.py,sha256=Z7L_rYdl6SYBQXAgtr18FVcmfMH9kP1fBWrc2W63g2c,2250
-torch_geometric/loader/utils.py,sha256=
+torch_geometric/loader/utils.py,sha256=DgGHK6kNu7ZZIZuaT0Ya_4rUctBMMKyBBSdHhuU389w,14903
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
-torch_geometric/metrics/link_pred.py,sha256=
+torch_geometric/metrics/link_pred.py,sha256=bacmFGn7rm0iF2wOJdAW-iTZ04bOuiS-7ur2K-MZKlA,31684
 torch_geometric/nn/__init__.py,sha256=tTEKDy4vpjPNKyG1Vg9GIx7dVFJuQtBoh2M19ascGpo,880
 torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
 torch_geometric/nn/encoding.py,sha256=82fpwyOx0-STFSAJ5AzG0p2WFC9u1M4KgmKIql8hSLc,3634
@@ -654,7 +654,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250918.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250918.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250918.dist-info/METADATA,sha256=
-pyg_nightly-2.7.0.dev20250918.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250920.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250920.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250920.dist-info/METADATA,sha256=PAeahjszlJpaI4WHs-eZPOYELiodtDDAPudxTK4MfTA,64145
+pyg_nightly-2.7.0.dev20250920.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250918'
+__version__ = '2.7.0.dev20250920'
 
 __all__ = [
     'Index',
torch_geometric/datasets/tag_dataset.py
CHANGED
@@ -1,3 +1,4 @@
+import csv
 import os
 import os.path as osp
 from collections.abc import Sequence
@@ -10,6 +11,7 @@ from tqdm import tqdm
 
 from torch_geometric.data import InMemoryDataset, download_google_url
 from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
 
 try:
     from pandas import DataFrame, read_csv
@@ -22,14 +24,16 @@ IndexType = Union[slice, Tensor, np.ndarray, Sequence]
 
 class TAGDataset(InMemoryDataset):
     r"""The Text Attributed Graph datasets from the
-    `"Learning on Large-scale Text-attributed Graphs via Variational Inference
-    " <https://arxiv.org/abs/2210.14709>`_ paper.
+    `"Learning on Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
+    LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
+    Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
     This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
     into Text Attributed Graph that each node in graph is associate with a
-    raw text, that dataset can be adapt to DataLoader (for LM training)
-    and NeighborLoader(for GNN training). In addition, this class can be
-    use as a wrapper class by convert a InMemoryDataset with Tokenizer and
-    text into Text Attributed Graph.
+    raw text, LLM prediction and explanation, that dataset can be adapt to
+    DataLoader (for LM training) and NeighborLoader(for GNN training).
+    In addition, this class can be use as a wrapper class by convert a
+    InMemoryDataset with Tokenizer and text into Text Attributed Graph.
 
     Args:
         root (str): Root directory where the dataset should be saved.
@@ -51,22 +55,35 @@ class TAGDataset(InMemoryDataset):
             or not, default: False
         force_reload (bool): default: False
     .. note::
-        See `example/
+        See `example/llm/glem.py` for example usage
     """
     raw_text_id = {
         'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
         'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
     }
 
-
-
-
-
-
-
+    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
+
+    llm_explanation_id = {
+        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        dataset: InMemoryDataset,
+        tokenizer_name: str,
+        text: Optional[List[str]] = None,
+        split_idx: Optional[Dict[str, Tensor]] = None,
+        tokenize_batch_size: int = 256,
+        token_on_disk: bool = False,
+        text_on_disk: bool = False,
+        force_reload: bool = False,
+    ) -> None:
         # list the vars you want to pass in before run download & process
         self.name = dataset.name
         self.text = text
+        self.llm_prediction_topk = 5
         self.tokenizer_name = tokenizer_name
         from transformers import AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
@@ -93,8 +110,9 @@ class TAGDataset(InMemoryDataset):
                 "is_gold mask, please pass splited index "
                 "in format of dictionaty with 'train', 'valid' "
                 "'test' index tensor to 'split_idx'")
-        if
-
+        if text_on_disk:
+            if text is not None:
+                self.save_node_text(text)
         self.text_on_disk = text_on_disk
         # init will call download and process
         super().__init__(self.root, transform=None, pre_transform=None,
@@ -119,6 +137,10 @@ class TAGDataset(InMemoryDataset):
         self.token_on_disk = token_on_disk
         self.tokenize_batch_size = tokenize_batch_size
         self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self._llm_explanation_token = self.tokenize_graph(
+            self.tokenize_batch_size, text_type='llm_explanation')
+        self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                              text_type='all')
         self.__num_classes__ = dataset.num_classes
 
     @property
@@ -146,6 +168,19 @@ class TAGDataset(InMemoryDataset):
             self._token = self.tokenize_graph()
         return self._token
 
+    @property
+    def llm_explanation_token(self) -> Dict[str, Tensor]:
+        if self._llm_explanation_token is None:  # lazy load
+            self._llm_explanation_token = self.tokenize_graph(
+                text_type='llm_explanation')
+        return self._llm_explanation_token
+
+    @property
+    def all_token(self) -> Dict[str, Tensor]:
+        if self._all_token is None:  # lazy load
+            self._all_token = self.tokenize_graph(text_type='all')
+        return self._all_token
+
     # load is_gold after init
     @property
     def is_gold(self) -> Tensor:
@@ -194,10 +229,17 @@ class TAGDataset(InMemoryDataset):
             folder=f'{self.root}/raw',
             filename='node-text.csv.gz',
             log=True)
-
-
+        self.text = list(read_csv(raw_text_path)['text'])
+        print('downloading llm explanations')
+        llm_explanation_path = download_google_url(
+            id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
+            filename='node-gpt-response.csv.gz', log=True)
+        self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+        print('downloading llm predictions')
+        fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
 
     def process(self) -> None:
+        # process Title and Abstraction
         if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
             text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
             self.text = list(text_df['text'])
@@ -212,6 +254,42 @@ class TAGDataset(InMemoryDataset):
                 "The raw text of each node is not specified"
                 "Please pass in 'text' when convert your dataset "
                 "to Text Attribute Graph Dataset")
+        # process LLM explanation and prediction
+        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
+        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
+        if osp.exists(llm_explanation_path) and osp.exists(
+                llm_prediction_path):
+            # load LLM explanation
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            # load LLM prediction
+            preds = []
+            with open(llm_prediction_path) as file:
+                reader = csv.reader(file)
+                for row in reader:
+                    inner_list = []
+                    for value in row:
+                        inner_list.append(int(value))
+                    preds.append(inner_list)
+
+            pl = torch.zeros(len(preds), self.llm_prediction_topk,
+                             dtype=torch.long)
+            for i, pred in enumerate(preds):
+                pl[i][:len(pred)] = torch.tensor(
+                    pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+        elif self.name in self.llm_explanation_id:
+            self.download()
+        else:
+            print(
+                'The dataset is not ogbn-arxiv,'
+                'please pass in your llm explanation list to `llm_explanation`'
+                'and llm prediction list to `llm_prediction`')
+        if self.llm_explanation is None or pl is None:
+            raise ValueError(
+                "The TAGDataset only have ogbn-arxiv LLM explanations"
+                "and predictions in default. The llm explanation and"
+                "prediction of each node is not specified."
+                "Please pass in 'llm_explanation' and 'llm_prediction' when"
+                "convert your dataset to Text Attribute Graph Dataset")
 
     def save_node_text(self, text: List[str]) -> None:
         node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
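The `pl` tensor built in this hunk packs the variable-length top-k LLM label lists into a fixed `(num_nodes, llm_prediction_topk)` matrix, shifting every label by +1 so that 0 is reserved for padding and label 0 stays distinguishable from an empty slot. A standalone sketch of the same packing, using made-up label lists:

    import torch

    preds = [[3, 7], [1], [5, 0, 2]]  # variable-length top-k label lists
    topk = 5

    # Shift labels by +1 so that 0 can act as the padding value:
    pl = torch.zeros(len(preds), topk, dtype=torch.long)
    for i, pred in enumerate(preds):
        pl[i][:len(pred)] = torch.tensor(pred[:topk], dtype=torch.long) + 1

    print(pl)
    # tensor([[4, 8, 0, 0, 0],
    #         [2, 0, 0, 0, 0],
    #         [6, 1, 3, 0, 0]])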
@@ -224,22 +302,39 @@ class TAGDataset(InMemoryDataset):
         text_df.to_csv(osp.join(node_text_path), compression='gzip',
                        index=False)
 
-    def tokenize_graph(self, batch_size: int = 256
+    def tokenize_graph(self, batch_size: int = 256,
+                       text_type: str = 'raw_text') -> Dict[str, Tensor]:
         r"""Tokenizing the text associate with each node, running in cpu.
 
         Args:
             batch_size (Optional[int]): batch size of list of text for
                 generating emebdding
+            text_type (Optional[str]): type of text
         Returns:
             Dict[str, torch.Tensor]: tokenized graph
         """
+        assert text_type in ['raw_text', 'llm_explanation', 'all']
+        if text_type == 'raw_text':
+            _text = self.text
+        elif text_type == 'llm_explanation':
+            _text = self.llm_explanation
+        elif text_type == 'all':
+            if self.text is None or self.llm_explanation is None:
+                raise ValueError("The TAGDataset need text and llm explanation"
+                                 "for tokenizing all text")
+            _text = [
+                f'{raw_txt} Explanation: {exp_txt}'
+                for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
+            ]
+
         data_len = 0
-        if
-            data_len = len(
+        if _text is not None:
+            data_len = len(_text)
         else:
             raise ValueError("The TAGDataset need text for tokenization")
         token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
-        path = os.path.join(self.processed_dir, 'token',
+        path = os.path.join(self.processed_dir, 'token', text_type,
+                            self.tokenizer_name)
         # Check if the .pt files already exist
         token_files_exist = any(
             os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
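With `text_type='all'`, each node's training string is its raw text followed by the LLM explanation, in the `f'{raw_txt} Explanation: {exp_txt}'` format used above. An illustration with invented strings:

    text = ['Variational inference for text-attributed graphs ...']
    llm_explanation = ['This paper most likely belongs to cs.LG because ...']

    _text = [
        f'{raw_txt} Explanation: {exp_txt}'
        for raw_txt, exp_txt in zip(text, llm_explanation)
    ]
    print(_text[0])
    # 'Variational inference for text-attributed graphs ... Explanation:
    #  This paper most likely belongs to cs.LG because ...'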
@@ -256,12 +351,12 @@ class TAGDataset(InMemoryDataset):
         all_encoded_token = {k: [] for k in token_keys}
         pbar = tqdm(total=data_len)
 
-        pbar.set_description('Tokenizing Text Attributed Graph')
+        pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
         for i in range(0, data_len, batch_size):
             end_index = min(data_len, i + batch_size)
-            token = self.tokenizer(
-
-
+            token = self.tokenizer(_text[i:end_index], padding='max_length',
+                                   truncation=True, max_length=512,
+                                   return_tensors="pt")
             for k in token.keys():
                 all_encoded_token[k].append(token[k])
             pbar.update(end_index - i)
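Each call to the tokenizer above yields one fixed-shape tensor per key collected in `all_encoded_token`. A minimal sketch of a single batch step (assuming the `transformers` package and an arbitrary small BERT checkpoint; the dataset uses whatever `tokenizer_name` was passed to it):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')
    token = tokenizer(['a text attributed node', 'another node'],
                      padding='max_length', truncation=True, max_length=512,
                      return_tensors='pt')
    print({k: tuple(v.shape) for k, v in token.items()})
    # {'input_ids': (2, 512), 'token_type_ids': (2, 512),
    #  'attention_mask': (2, 512)}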
@@ -289,10 +384,18 @@ class TAGDataset(InMemoryDataset):
 
         Args:
             tag_dataset (TAGDataset): the parent dataset
+            text_type (str): type of text
         """
-        def __init__(self, tag_dataset: 'TAGDataset'
+        def __init__(self, tag_dataset: 'TAGDataset',
+                     text_type: str = 'raw_text') -> None:
+            assert text_type in ['raw_text', 'llm_explanation', 'all']
             self.tag_dataset = tag_dataset
-
+            if text_type == 'raw_text':
+                self.token = tag_dataset.token
+            elif text_type == 'llm_explanation':
+                self.token = tag_dataset.llm_explanation_token
+            elif text_type == 'all':
+                self.token = tag_dataset.all_token
             assert tag_dataset._data is not None
             self._data = tag_dataset._data
 
@@ -312,7 +415,8 @@ class TAGDataset(InMemoryDataset):
 
         # for LM training
         def __getitem__(
-
+            self,
+            node_id: IndexType,
         ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
             r"""This function will override the function in
             torch.utils.data.Dataset, and will be called when you
@@ -343,8 +447,8 @@ class TAGDataset(InMemoryDataset):
         def __repr__(self) -> str:
             return f'{self.__class__.__name__}()'
 
-    def to_text_dataset(self) -> TextDataset:
+    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
         r"""Factory Build text dataset from Text Attributed Graph Dataset
         each data point is node's associated text token.
         """
-        return TAGDataset.TextDataset(self)
+        return TAGDataset.TextDataset(self, text_type)
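Taken together, these changes let `TAGDataset` tokenize and serve three views of every node: the raw text, the LLM explanation, or both concatenated. A usage sketch (assuming the `ogb` package for `ogbn-arxiv` and an arbitrary tokenizer checkpoint; `example/llm/glem.py` shows the full pipeline):

    from ogb.nodeproppred import PygNodePropPredDataset
    from torch_geometric.datasets import TAGDataset

    dataset = PygNodePropPredDataset('ogbn-arxiv', root='./data')
    tag = TAGDataset('./data', dataset, tokenizer_name='prajjwal1/bert-tiny',
                     split_idx=dataset.get_idx_split())

    raw = tag.token                  # raw title/abstract tokens
    exp = tag.llm_explanation_token  # LLM explanation tokens
    both = tag.all_token             # '<raw text> Explanation: <explanation>'

    # Wrap one view as a plain text dataset for LM training:
    text_dataset = tag.to_text_dataset(text_type='all')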
torch_geometric/loader/utils.py
CHANGED
@@ -256,14 +256,6 @@ def filter_custom_hetero_store(
     # Construct a new `HeteroData` object:
     data = custom_cls() if custom_cls is not None else HeteroData()
 
-    # Filter edge storage:
-    # TODO support edge attributes
-    for attr in graph_store.get_all_edge_attrs():
-        key = attr.edge_type
-        if key in row_dict and key in col_dict:
-            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
-            data[attr.edge_type].edge_index = edge_index
-
     # Filter node storage:
     required_attrs = []
     for attr in feature_store.get_all_tensor_attrs():
@@ -280,6 +272,14 @@ def filter_custom_hetero_store(
     for i, attr in enumerate(required_attrs):
         data[attr.group_name][attr.attr_name] = tensors[i]
 
+    # Filter edge storage:
+    # TODO support edge attributes
+    for attr in graph_store.get_all_edge_attrs():
+        key = attr.edge_type
+        if key in row_dict and key in col_dict:
+            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
+            data[attr.edge_type].edge_index = edge_index
+
     return data
 
 
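The hunks above only move the edge-filtering block after the node-filtering block, so node storages are populated before any `edge_index` is assigned. An illustrative `HeteroData` snippet (not the loader's actual data) showing one plausible motivation for the reordering: node counts then come from the node stores instead of being inferred from edge indices.

    import torch
    from torch_geometric.data import HeteroData

    data = HeteroData()
    data['paper'].x = torch.randn(4, 16)   # node storages first
    data['author'].x = torch.randn(3, 16)
    data['author', 'writes', 'paper'].edge_index = torch.tensor(
        [[0, 1, 2], [0, 2, 3]])
    print(data['paper'].num_nodes)  # 4, read from `x` rather than inferred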
torch_geometric/metrics/link_pred.py
CHANGED
@@ -53,7 +53,7 @@ class LinkPredMetricData:
 
         # Flatten both prediction and ground-truth indices, and determine
         # overlaps afterwards via `torch.searchsorted`.
-        max_index = max(
+        max_index = max(
             self.pred_index_mat.max()
             if self.pred_index_mat.numel() > 0 else 0,
             self.edge_label_index[1].max()
@@ -820,9 +820,10 @@ class LinkPredPersonalization(_LinkPredMetric):
             right = pred[col.cpu()].to(device)
 
             # Use offset to work around applying `isin` along a specific dim:
-            i = max(left.max(), right.max()) + 1
-
-
+            i = max(int(left.max()), int(right.max())) + 1
+            idx = torch.arange(0, i * row.size(0), i, device=device)
+            idx = idx.view(-1, 1)
+            isin = torch.isin(left + idx, right + idx)
 
             # Compute personalization via average inverse cosine similarity:
             cos = isin.sum(dim=-1) / pred.size(1)
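The offset trick in the second hunk exists because `torch.isin` has no `dim` argument: shifting each row into its own disjoint integer range (multiples of `i`) lets a single flat `isin` call answer the row-wise membership question. A self-contained sketch:

    import torch

    left = torch.tensor([[0, 2, 3], [1, 2, 4]])
    right = torch.tensor([[2, 3], [0, 1]])

    i = max(int(left.max()), int(right.max())) + 1  # width of each row's range
    idx = torch.arange(0, i * left.size(0), i).view(-1, 1)
    isin = torch.isin(left + idx, right + idx)
    print(isin)
    # tensor([[False,  True,  True],
    #         [ True, False, False]])

The new `int(...)` casts also make `i` a plain Python integer rather than a 0-dim tensor, presumably so the `i * row.size(0)` arithmetic feeding `torch.arange` cannot overflow the tensor's integer dtype.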
{pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/WHEEL
RENAMED
File without changes
{pyg_nightly-2.7.0.dev20250918.dist-info → pyg_nightly-2.7.0.dev20250920.dist-info}/licenses/LICENSE
RENAMED
File without changes