pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/datasets/tag_dataset.py
ADDED
@@ -0,0 +1,350 @@
import os
import os.path as osp
from collections.abc import Sequence
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.data import InMemoryDataset, download_google_url
from torch_geometric.data.data import BaseData

try:
    from pandas import DataFrame, read_csv
    WITH_PANDAS = True
except ImportError:
    WITH_PANDAS = False

IndexType = Union[slice, Tensor, np.ndarray, Sequence]


class TAGDataset(InMemoryDataset):
    r"""The Text Attributed Graph datasets from the
    `"Learning on Large-scale Text-attributed Graphs via Variational
    Inference" <https://arxiv.org/abs/2210.14709>`_ paper.

    This dataset transforms :obj:`ogbn-products` or :obj:`ogbn-arxiv` into a
    text-attributed graph in which each node is associated with a raw text
    string. The resulting dataset can be adapted to a :class:`DataLoader`
    (for LM training) and a :class:`NeighborLoader` (for GNN training). In
    addition, this class can be used as a wrapper that converts an
    :class:`InMemoryDataset`, together with a tokenizer and raw text, into a
    text-attributed graph.

    Args:
        root (str): Root directory where the dataset should be saved.
        dataset (InMemoryDataset): The dataset to be transformed
            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
        tokenizer_name (str): The tokenizer name for the language model.
            Be sure to use the same tokenizer name as the model ID of your
            model repository on huggingface.co.
        text (List[str]): A list of raw text associated with the nodes.
            The order of the list must align with the node list.
        split_idx (Optional[Dict[str, torch.Tensor]]): A dictionary holding
            the split indices. Required if your dataset does not provide a
            :meth:`get_idx_split` function.
        tokenize_batch_size (int): The batch size used when tokenizing text.
            Tokenization runs on the CPU. (default: :obj:`256`)
        token_on_disk (bool): Whether to save the tokens as :obj:`.pt` files
            on disk. (default: :obj:`False`)
        text_on_disk (bool): Whether to save the given text (a list of
            strings) as a dataframe on disk. (default: :obj:`False`)
        force_reload (bool): Whether to re-process the dataset.
            (default: :obj:`False`)

    .. note::
        See `example/llm_plus_gnn/glem.py` for example usage.
    """
    raw_text_id = {
        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
    }

    def __init__(self, root: str, dataset: InMemoryDataset,
                 tokenizer_name: str, text: Optional[List[str]] = None,
                 split_idx: Optional[Dict[str, Tensor]] = None,
                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
                 text_on_disk: bool = False,
                 force_reload: bool = False) -> None:
        # Set the attributes needed before download & process are run:
        self.name = dataset.name
        self.text = text
        self.tokenizer_name = tokenizer_name
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.dir_name = '_'.join(dataset.name.split('-'))
        self.root = osp.join(root, self.dir_name)
        missing_str_list = []
        if not WITH_PANDAS:
            missing_str_list.append('pandas')
        if len(missing_str_list) > 0:
            missing_str = ' '.join(missing_str_list)
            error_out = f"`pip install {missing_str}` to use this dataset."
            raise ImportError(error_out)
        if hasattr(dataset, 'get_idx_split'):
            self.split_idx = dataset.get_idx_split()
        elif split_idx is not None:
            self.split_idx = split_idx
        else:
            raise ValueError("TAGDataset needs split indices to generate "
                             "the is_gold mask. Please pass a dictionary "
                             "with 'train', 'valid' and 'test' index "
                             "tensors to 'split_idx'")
        if text is not None and text_on_disk:
            self.save_node_text(text)
        self.text_on_disk = text_on_disk
        # __init__() will call download and process:
        super().__init__(self.root, transform=None, pre_transform=None,
                         pre_filter=None, force_reload=force_reload)
        # After downloading and processing, the dataset has to hold
        # `BaseData` as `_data`:
        assert dataset._data is not None
        self._data = dataset._data  # Reassign reference.
        assert self._data is not None
        assert dataset._data.y is not None
        assert isinstance(self._data, BaseData)
        assert self._data.num_nodes is not None
        assert isinstance(dataset._data.num_nodes, int)
        assert isinstance(self._data.num_nodes, int)
        self._n_id = torch.arange(self._data.num_nodes)
        is_good_tensor = self.load_gold_mask()
        self._is_gold = is_good_tensor.squeeze()
        self._data['is_gold'] = is_good_tensor
        if self.text is not None and len(self.text) != self._data.num_nodes:
            raise ValueError("The number of text sequences in 'text' should "
                             "be equal to the number of nodes!")
        self.token_on_disk = token_on_disk
        self.tokenize_batch_size = tokenize_batch_size
        self._token = self.tokenize_graph(self.tokenize_batch_size)
        self.__num_classes__ = dataset.num_classes

    @property
    def num_classes(self) -> int:
        return self.__num_classes__

    @property
    def raw_file_names(self) -> List[str]:
        file_names = []
        for root, _, files in os.walk(osp.join(self.root, 'raw')):
            for file in files:
                file_names.append(file)
        return file_names

    @property
    def processed_file_names(self) -> List[str]:
        return [
            'geometric_data_processed.pt', 'pre_filter.pt',
            'pre_transformed.pt'
        ]

    @property
    def token(self) -> Dict[str, Tensor]:
        if self._token is None:  # Lazy load:
            self._token = self.tokenize_graph()
        return self._token

    # Load `is_gold` lazily after initialization:
    @property
    def is_gold(self) -> Tensor:
        if self._is_gold is None:
            print('lazy load is_gold!!')
            self._is_gold = self.load_gold_mask()
        return self._is_gold

    def get_n_id(self, node_idx: IndexType) -> Tensor:
        if self._n_id is None:
            assert self._data is not None
            assert self._data.num_nodes is not None
            assert isinstance(self._data.num_nodes, int)
            self._n_id = torch.arange(self._data.num_nodes)
        return self._n_id[node_idx]

    def load_gold_mask(self) -> Tensor:
        r"""Use the original training split as the gold split, generating an
        is_gold mask for distinguishing ground-truth labels from pseudo
        labels.
        """
        train_split_idx = self.get_idx_split()['train']
        assert self._data is not None
        assert self._data.num_nodes is not None
        assert isinstance(self._data.num_nodes, int)
        is_good_tensor = torch.zeros(self._data.num_nodes,
                                     dtype=torch.bool).view(-1, 1)
        is_good_tensor[train_split_idx] = True
        return is_good_tensor

    def get_gold(self, node_idx: IndexType) -> Tensor:
        r"""Get the gold mask for the given node indices.

        Args:
            node_idx (torch.Tensor): A tensor containing node indices.
        """
        if self._is_gold is None:
            self._is_gold = self.is_gold
        return self._is_gold[node_idx]

    def get_idx_split(self) -> Dict[str, Tensor]:
        return self.split_idx

    def download(self) -> None:
        print('downloading raw text')
        raw_text_path = download_google_url(id=self.raw_text_id[self.name],
                                            folder=f'{self.root}/raw',
                                            filename='node-text.csv.gz',
                                            log=True)
        text_df = read_csv(raw_text_path)
        self.text = list(text_df['text'])

    def process(self) -> None:
        if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
            text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
            self.text = list(text_df['text'])
        elif self.name in self.raw_text_id:
            self.download()
        else:
            print('The dataset is neither ogbn-products nor ogbn-arxiv, '
                  'please pass in your raw text string list to `text`')
        if self.text is None:
            raise ValueError("TAGDataset only provides raw text for "
                             "ogbn-products and ogbn-arxiv by default. "
                             "The raw text of each node is not specified. "
                             "Please pass in 'text' when converting your "
                             "dataset to a Text Attributed Graph dataset")

    def save_node_text(self, text: List[str]) -> None:
        node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
        if osp.exists(node_text_path):
            print(f'The raw text already exists at {node_text_path}')
        else:
            print(f'Saving raw text file at {node_text_path}')
            os.makedirs(f'{self.root}/raw', exist_ok=True)
            text_df = DataFrame(text, columns=['text'])
            text_df.to_csv(osp.join(node_text_path), compression='gzip',
                           index=False)

    def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
        r"""Tokenize the text associated with each node (runs on the CPU).

        Args:
            batch_size (Optional[int]): The batch size of the text list used
                for generating embeddings.
        Returns:
            Dict[str, torch.Tensor]: The tokenized graph.
        """
        data_len = 0
        if self.text is not None:
            data_len = len(self.text)
        else:
            raise ValueError("TAGDataset needs text for tokenization")
        token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
        path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
        # Check if the .pt files already exist:
        token_files_exist = any(
            os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)

        if token_files_exist and self.token_on_disk:
            print('Found tokenized file, loading may take several minutes...')
            all_encoded_token = {
                k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
                for k in token_keys
                if os.path.exists(os.path.join(path, f'{k}.pt'))
            }
            return all_encoded_token

        all_encoded_token = {k: [] for k in token_keys}
        pbar = tqdm(total=data_len)

        pbar.set_description('Tokenizing Text Attributed Graph')
        for i in range(0, data_len, batch_size):
            end_index = min(data_len, i + batch_size)
            token = self.tokenizer(self.text[i:end_index],
                                   padding='max_length', truncation=True,
                                   max_length=512, return_tensors="pt")
            for k in token.keys():
                all_encoded_token[k].append(token[k])
            pbar.update(end_index - i)
        pbar.close()

        all_encoded_token = {
            k: torch.cat(v)
            for k, v in all_encoded_token.items() if len(v) > 0
        }
        if self.token_on_disk:
            os.makedirs(path, exist_ok=True)
            print('Saving tokens on Disk')
            for k, tensor in all_encoded_token.items():
                torch.save(tensor, os.path.join(path, f'{k}.pt'))
                print('Token saved:', os.path.join(path, f'{k}.pt'))
        os.environ["TOKENIZERS_PARALLELISM"] = 'true'  # Suppress warning.
        return all_encoded_token

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    class TextDataset(torch.utils.data.Dataset):
        r"""This nested dataset provides the textual data of each node in
        the graph. Use the :meth:`to_text_dataset` factory method to create
        it from a :class:`TAGDataset`.

        Args:
            tag_dataset (TAGDataset): The parent dataset.
        """
        def __init__(self, tag_dataset: 'TAGDataset') -> None:
            self.tag_dataset = tag_dataset
            self.token = tag_dataset.token
            assert tag_dataset._data is not None
            self._data = tag_dataset._data

            assert tag_dataset._data.y is not None
            self.labels = tag_dataset._data.y

        def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
            r"""Called inside :meth:`__getitem__`.

            Args:
                node_idx (IndexType): The node indices selected in a batch.
            Returns:
                items (Dict[str, Tensor]): The input for the LM.
            """
            items = {k: v[node_idx] for k, v in self.token.items()}
            return items

        # For LM training:
        def __getitem__(
            self, node_id: IndexType
        ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
            r"""Overrides :meth:`torch.utils.data.Dataset.__getitem__` and
            is called when iterating over batches in the data loader. Make
            sure all of the following key-value pairs are present in the
            returned dictionary.

            Args:
                node_id (List[int]): A list of node indices for selecting
                    tokens, labels, etc., when iterating over the data
                    loader for the LM.
            Returns:
                items (dict): The input key-value pairs for language model
                    training and inference.
            """
            item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
            item['input'] = self.get_token(node_id)
            item['labels'] = self.labels[node_id]
            item['is_gold'] = self.tag_dataset.get_gold(node_id)
            item['n_id'] = self.tag_dataset.get_n_id(node_id)
            return item

        def __len__(self) -> int:
            assert self._data.num_nodes is not None
            return self._data.num_nodes

        def get(self, idx: int) -> BaseData:
            return self._data

        def __repr__(self) -> str:
            return f'{self.__class__.__name__}()'

    def to_text_dataset(self) -> TextDataset:
        r"""Factory method that builds a text dataset from a Text Attributed
        Graph dataset, where each data point is a node's associated text
        token.
        """
        return TAGDataset.TextDataset(self)
torch_geometric/datasets/upfd.py
CHANGED
@@ -3,7 +3,6 @@ import os.path as osp
 from typing import Callable, List, Optional
 
 import numpy as np
-import scipy.sparse as sp
 import torch
 
 from torch_geometric.data import (
@@ -130,6 +129,8 @@ class UPFD(InMemoryDataset):
         os.remove(path)
 
     def process(self) -> None:
+        import scipy.sparse as sp
+
         x = sp.load_npz(
             osp.join(self.raw_dir, f'new_{self.feature}_feature.npz'))
         x = torch.from_numpy(x.todense()).to(torch.float)
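The upfd.py hunk above moves the scipy.sparse import from module scope into process(), turning SciPy into a lazily imported dependency that is only required when the raw files are actually processed. A generic sketch of the pattern (the class and file names are illustrative, not from the diff):

class MyDataset:  # hypothetical class, for illustration only
    def process(self) -> None:
        # Importing inside the method keeps the dependency optional: users
        # who never trigger processing (e.g., because the processed files
        # already exist) do not need scipy installed.
        import scipy.sparse as sp

        x = sp.load_npz('raw/features.npz')
        print(x.shape)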
torch_geometric/datasets/web_qsp_dataset.py
ADDED
@@ -0,0 +1,246 @@
# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630
from typing import Any, Dict, List, Tuple, no_type_check

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.nn.nlp import SentenceTransformer


@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 3,
    cost_e: float = 0.5,
) -> Tuple[Data, str]:
    c = 0.01

    from pcst_fast import pcst_fast

    root = -1
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}
    mapping_e = {}
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_vertices = vertices[vertices >= data.num_nodes]
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]),
        edge_attr=data.edge_attr[selected_edges],
    )

    return data, desc


class WebQSPDataset(InMemoryDataset):
    r"""The WebQuestionsSP dataset of the `"The Value of Semantic Parse
    Labeling for Knowledge Base Question Answering"
    <https://aclanthology.org/P16-2033/>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        use_pcst (bool, optional): Whether to preprocess the dataset's graph
            with PCST or return the full graphs. (default: :obj:`True`)
    """
    def __init__(
        self,
        root: str,
        split: str = "train",
        force_reload: bool = False,
        use_pcst: bool = True,
    ) -> None:
        self.use_pcst = use_pcst
        super().__init__(root, force_reload=force_reload)

        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.load(path)

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def process(self) -> None:
        import datasets
        import pandas as pd

        datasets = datasets.load_dataset('rmanluo/RoG-webqsp')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_name = 'sentence-transformers/all-roberta-large-v1'
        model = SentenceTransformer(model_name).to(device)
        model.eval()

        for dataset, path in zip(
            [datasets['train'], datasets['validation'], datasets['test']],
            self.processed_paths,
        ):
            questions = [example["question"] for example in dataset]
            question_embs = model.encode(
                questions,
                batch_size=256,
                output_device='cpu',
            )

            data_list = []
            for i, example in enumerate(tqdm(dataset)):
                raw_nodes: Dict[str, int] = {}
                raw_edges = []
                for tri in example["graph"]:
                    h, r, t = tri
                    h = h.lower()
                    t = t.lower()
                    if h not in raw_nodes:
                        raw_nodes[h] = len(raw_nodes)
                    if t not in raw_nodes:
                        raw_nodes[t] = len(raw_nodes)
                    raw_edges.append({
                        "src": raw_nodes[h],
                        "edge_attr": r,
                        "dst": raw_nodes[t]
                    })
                nodes = pd.DataFrame([{
                    "node_id": v,
                    "node_attr": k,
                } for k, v in raw_nodes.items()],
                                     columns=["node_id", "node_attr"])
                edges = pd.DataFrame(raw_edges,
                                     columns=["src", "edge_attr", "dst"])

                nodes.node_attr = nodes.node_attr.fillna("")
                x = model.encode(
                    nodes.node_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_attr = model.encode(
                    edges.edge_attr.tolist(),
                    batch_size=256,
                    output_device='cpu',
                )
                edge_index = torch.tensor([
                    edges.src.tolist(),
                    edges.dst.tolist(),
                ], dtype=torch.long)

                question = f"Question: {example['question']}\nAnswer: "
                label = ('|').join(example['answer']).lower()
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                )
                if self.use_pcst and len(nodes) > 0 and len(edges) > 0:
                    data, desc = retrieval_via_pcst(
                        data,
                        question_embs[i],
                        nodes,
                        edges,
                        topk=3,
                        topk_e=5,
                        cost_e=0.5,
                    )
                else:
                    desc = nodes.to_csv(index=False) + "\n" + edges.to_csv(
                        index=False,
                        columns=["src", "edge_attr", "dst"],
                    )

                data.question = question
                data.label = label
                data.desc = desc
                data_list.append(data)

            self.save(data_list, path)
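And a minimal usage sketch for the new WebQSPDataset (illustrative, not part of the diff); it assumes the optional dependencies imported during processing (HuggingFace datasets, pandas, transformers and pcst_fast) are installed:

from torch_geometric.datasets import WebQSPDataset

# Downloading, embedding, and PCST pruning all happen inside process():
dataset = WebQSPDataset(root='data/webqsp', split='train', use_pcst=True)

data = dataset[0]
print(data.question)  # "Question: ...\nAnswer: "
print(data.desc)  # The textualized subgraph (CSV) used to prompt the LLM.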
torch_geometric/datasets/webkb.py
CHANGED
@@ -102,7 +102,7 @@ class WebKB(InMemoryDataset):
             download_url(f'{self.url}/splits/{f}', self.raw_dir)
 
     def process(self) -> None:
-        with open(self.raw_paths[0]
+        with open(self.raw_paths[0]) as f:
            lines = f.read().split('\n')[1:-1]
            xs = [[float(value) for value in line.split('\t')[1].split(',')]
                  for line in lines]
@@ -111,7 +111,7 @@ class WebKB(InMemoryDataset):
            ys = [int(line.split('\t')[2]) for line in lines]
            y = torch.tensor(ys, dtype=torch.long)
 
-        with open(self.raw_paths[1]
+        with open(self.raw_paths[1]) as f:
            lines = f.read().split('\n')[1:-1]
            edge_indices = [[int(value) for value in line.split('\t')]
                            for line in lines]
torch_geometric/datasets/wikics.py
CHANGED
@@ -65,7 +65,7 @@ class WikiCS(InMemoryDataset):
             download_url(f'{self.url}/{name}', self.raw_dir)
 
     def process(self) -> None:
-        with open(self.raw_paths[0]
+        with open(self.raw_paths[0]) as f:
            data = json.load(f)
 
        x = torch.tensor(data['features'], dtype=torch.float)
torch_geometric/datasets/wikidata.py
CHANGED
@@ -10,6 +10,7 @@ from torch_geometric.data import (
     download_url,
     extract_tar,
 )
+from torch_geometric.io import fs
 
 
 class Wikidata5M(InMemoryDataset):
@@ -99,7 +100,7 @@ class Wikidata5M(InMemoryDataset):
             values = line.strip().split('\t')
             entity_to_id[values[0]] = i
 
-        x =
+        x = fs.torch_load(self.raw_paths[1])
 
         edge_indices = []
         edge_types = []
@@ -107,7 +108,7 @@ class Wikidata5M(InMemoryDataset):
 
         rel_to_id: Dict[str, int] = {}
         for split, path in enumerate(self.raw_paths[2:]):
-            with open(path
+            with open(path) as f:
                for line in f:
                    head, rel, tail = line[:-1].split('\t')
                    src = entity_to_id[head]
torch_geometric/datasets/wikipedia_network.py
CHANGED
@@ -105,7 +105,7 @@ class WikipediaNetwork(InMemoryDataset):
 
     def process(self) -> None:
         if self.geom_gcn_preprocess:
-            with open(self.raw_paths[0]
+            with open(self.raw_paths[0]) as f:
                lines = f.read().split('\n')[1:-1]
            xs = [[float(value) for value in line.split('\t')[1].split(',')]
                  for line in lines]
@@ -113,7 +113,7 @@ class WikipediaNetwork(InMemoryDataset):
            ys = [int(line.split('\t')[2]) for line in lines]
            y = torch.tensor(ys, dtype=torch.long)
 
-            with open(self.raw_paths[1]
+            with open(self.raw_paths[1]) as f:
                lines = f.read().split('\n')[1:-1]
                edge_indices = [[int(value) for value in line.split('\t')]
                                for line in lines]