pyg-nightly 2.7.0.dev20250919__py3-none-any.whl → 2.7.0.dev20250920__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.

Note: this version of pyg-nightly has been flagged as a potentially problematic release.

--- pyg_nightly-2.7.0.dev20250919.dist-info/METADATA
+++ pyg_nightly-2.7.0.dev20250920.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250919
+Version: 2.7.0.dev20250920
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
--- pyg_nightly-2.7.0.dev20250919.dist-info/RECORD
+++ pyg_nightly-2.7.0.dev20250920.dist-info/RECORD
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=OLGPhTHC1wmAq6rg69s1sbJUFuptQaVPvG0ggAYNOlM,2292
+torch_geometric/__init__.py,sha256=0ij8VxVSK4T5A19Dr05CrBi0vbfTv2d0vTpB73hQsws,2292
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=ODB_8cwFUiwBUjngXn6-K5HHb7IDul7DDXuuGX7vj_0,8178
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -151,7 +151,7 @@ torch_geometric/datasets/shapenet.py,sha256=tn3HiQQAr6lxHrqxfOVaAtl40guwFYTXWCbS
 torch_geometric/datasets/shrec2016.py,sha256=cTLhctbqE0EUEvKddJFhPzDb1oLKXOth4O_WzsWtyMk,6323
 torch_geometric/datasets/snap_dataset.py,sha256=deJvB6cpIQ3bu_pcWoqgEo1-Kl_NcFi7ZSUci645X0U,9481
 torch_geometric/datasets/suite_sparse.py,sha256=eqjH4vAUq872qdk3YdLkZSwlu6r7HHpTgK0vEVGmY1s,3278
-torch_geometric/datasets/tag_dataset.py,sha256=qTnwr2N1tbWYeLGbItfv70UxQ3n1rKesjeVU3kcOCP8,14757
+torch_geometric/datasets/tag_dataset.py,sha256=jslijGCh37ip2YkrQLyvbk-1QRJ3yqFpmzuQSxckXrE,19402
 torch_geometric/datasets/taobao.py,sha256=CUcZpbWsNTasevflO8zqP0YvENy89P7wpKS4MHaDJ6Q,4170
 torch_geometric/datasets/teeth3ds.py,sha256=hZvhcq9lsQENNFr5hk50w2T3CgxE_tlnQfrCgN6uIDQ,9919
 torch_geometric/datasets/tosca.py,sha256=nUSF8NQT1GlkwWQLshjWmr8xORsvRHzzIqhUyDCvABc,4632
@@ -308,10 +308,10 @@ torch_geometric/loader/prefetch.py,sha256=z30TIcu3_6ZubllUOwNLunlq4RyQdFj36vPE5Q
 torch_geometric/loader/random_node_loader.py,sha256=rCmRXYv70SPxBo-Oh049eFEWEZDV7FmlRPzmjcoirXQ,2196
 torch_geometric/loader/shadow.py,sha256=_hCspYf9SlJYX0lqEjxFec9e9t1iMScNThOoWR1wQGM,4173
 torch_geometric/loader/temporal_dataloader.py,sha256=Z7L_rYdl6SYBQXAgtr18FVcmfMH9kP1fBWrc2W63g2c,2250
-torch_geometric/loader/utils.py,sha256=3hzKzIgB52QIZu7Jdn4JeXZaegIJinIQfIUP9DrUWUQ,14903
+torch_geometric/loader/utils.py,sha256=DgGHK6kNu7ZZIZuaT0Ya_4rUctBMMKyBBSdHhuU389w,14903
 torch_geometric/loader/zip_loader.py,sha256=3lt10fD15Rxm1WhWzypswGzCEwUz4h8OLCD1nE15yNg,3843
 torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7z48e94,699
-torch_geometric/metrics/link_pred.py,sha256=1_hE3KiRqAdZLI6QuUbjgyFC__mTyFu_RimM3bD8wRw,31678
+torch_geometric/metrics/link_pred.py,sha256=bacmFGn7rm0iF2wOJdAW-iTZ04bOuiS-7ur2K-MZKlA,31684
 torch_geometric/nn/__init__.py,sha256=tTEKDy4vpjPNKyG1Vg9GIx7dVFJuQtBoh2M19ascGpo,880
 torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
 torch_geometric/nn/encoding.py,sha256=82fpwyOx0-STFSAJ5AzG0p2WFC9u1M4KgmKIql8hSLc,3634
@@ -654,7 +654,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250919.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250919.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250919.dist-info/METADATA,sha256=IfaNYkgI-HE5ar5wS1k5XG9esWAh643uI1uvCOX7ChY,64145
-pyg_nightly-2.7.0.dev20250919.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250920.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250920.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250920.dist-info/METADATA,sha256=PAeahjszlJpaI4WHs-eZPOYELiodtDDAPudxTK4MfTA,64145
+pyg_nightly-2.7.0.dev20250920.dist-info/RECORD,,
--- torch_geometric/__init__.py
+++ torch_geometric/__init__.py
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250919'
+__version__ = '2.7.0.dev20250920'
 
 __all__ = [
     'Index',
--- torch_geometric/datasets/tag_dataset.py
+++ torch_geometric/datasets/tag_dataset.py
@@ -1,3 +1,4 @@
+import csv
 import os
 import os.path as osp
 from collections.abc import Sequence
@@ -10,6 +11,7 @@ from tqdm import tqdm
 
 from torch_geometric.data import InMemoryDataset, download_google_url
 from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
 
 try:
     from pandas import DataFrame, read_csv
@@ -22,14 +24,16 @@ IndexType = Union[slice, Tensor, np.ndarray, Sequence]
 
 class TAGDataset(InMemoryDataset):
     r"""The Text Attributed Graph datasets from the
-    `"Learning on Large-scale Text-attributed Graphs via Variational Inference
-    " <https://arxiv.org/abs/2210.14709>`_ paper.
+    `"Learning on Large-scale Text-attributed Graphs via Variational Inference"
+    <https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
+    LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
+    Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
     This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
     into Text Attributed Graph that each node in graph is associate with a
-    raw text, that dataset can be adapt to DataLoader (for LM training) and
-    NeighborLoader(for GNN training). In addition, this class can be use as a
-    wrapper class by convert a InMemoryDataset with Tokenizer and text into
-    Text Attributed Graph.
+    raw text, LLM prediction and explanation, that dataset can be adapt to
+    DataLoader (for LM training) and NeighborLoader(for GNN training).
+    In addition, this class can be use as a wrapper class by convert a
+    InMemoryDataset with Tokenizer and text into Text Attributed Graph.
 
     Args:
         root (str): Root directory where the dataset should be saved.
@@ -51,22 +55,35 @@ class TAGDataset(InMemoryDataset):
             or not, default: False
         force_reload (bool): default: False
     .. note::
-        See `example/llm_plus_gnn/glem.py` for example usage
+        See `example/llm/glem.py` for example usage
     """
     raw_text_id = {
         'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
         'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
     }
 
-    def __init__(self, root: str, dataset: InMemoryDataset,
-                 tokenizer_name: str, text: Optional[List[str]] = None,
-                 split_idx: Optional[Dict[str, Tensor]] = None,
-                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
-                 text_on_disk: bool = False,
-                 force_reload: bool = False) -> None:
+    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
+
+    llm_explanation_id = {
+        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
+    }
+
+    def __init__(
+        self,
+        root: str,
+        dataset: InMemoryDataset,
+        tokenizer_name: str,
+        text: Optional[List[str]] = None,
+        split_idx: Optional[Dict[str, Tensor]] = None,
+        tokenize_batch_size: int = 256,
+        token_on_disk: bool = False,
+        text_on_disk: bool = False,
+        force_reload: bool = False,
+    ) -> None:
         # list the vars you want to pass in before run download & process
         self.name = dataset.name
         self.text = text
+        self.llm_prediction_topk = 5
         self.tokenizer_name = tokenizer_name
         from transformers import AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
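As an aside, a minimal sketch of how the extended constructor is driven, modeled on the referenced example/llm/glem.py; the root path and Hugging Face tokenizer name below are illustrative assumptions, not values taken from this diff:

    from ogb.nodeproppred import PygNodePropPredDataset
    from torch_geometric.datasets import TAGDataset

    # Any AutoTokenizer-compatible checkpoint works, since __init__ calls
    # AutoTokenizer.from_pretrained(tokenizer_name) internally.
    dataset = PygNodePropPredDataset('ogbn-arxiv', root='./data')
    tag_dataset = TAGDataset('./data', dataset,
                             tokenizer_name='prajjwal1/bert-tiny',
                             split_idx=dataset.get_idx_split(),
                             token_on_disk=True)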
@@ -93,8 +110,9 @@ class TAGDataset(InMemoryDataset):
                 "is_gold mask, please pass splited index "
                 "in format of dictionaty with 'train', 'valid' "
                 "'test' index tensor to 'split_idx'")
-        if text is not None and text_on_disk:
-            self.save_node_text(text)
+        if text_on_disk:
+            if text is not None:
+                self.save_node_text(text)
         self.text_on_disk = text_on_disk
         # init will call download and process
         super().__init__(self.root, transform=None, pre_transform=None,
@@ -119,6 +137,10 @@ class TAGDataset(InMemoryDataset):
         self.token_on_disk = token_on_disk
         self.tokenize_batch_size = tokenize_batch_size
         self._token = self.tokenize_graph(self.tokenize_batch_size)
+        self._llm_explanation_token = self.tokenize_graph(
+            self.tokenize_batch_size, text_type='llm_explanation')
+        self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                              text_type='all')
         self.__num_classes__ = dataset.num_classes
 
     @property
@@ -146,6 +168,19 @@ class TAGDataset(InMemoryDataset):
             self._token = self.tokenize_graph()
         return self._token
 
+    @property
+    def llm_explanation_token(self) -> Dict[str, Tensor]:
+        if self._llm_explanation_token is None:  # lazy load
+            self._llm_explanation_token = self.tokenize_graph(
+                text_type='llm_explanation')
+        return self._llm_explanation_token
+
+    @property
+    def all_token(self) -> Dict[str, Tensor]:
+        if self._all_token is None:  # lazy load
+            self._all_token = self.tokenize_graph(text_type='all')
+        return self._all_token
+
     # load is_gold after init
     @property
     def is_gold(self) -> Tensor:
@@ -194,10 +229,17 @@ class TAGDataset(InMemoryDataset):
                                             folder=f'{self.root}/raw',
                                             filename='node-text.csv.gz',
                                             log=True)
-        text_df = read_csv(raw_text_path)
-        self.text = list(text_df['text'])
+        self.text = list(read_csv(raw_text_path)['text'])
+        print('downloading llm explanations')
+        llm_explanation_path = download_google_url(
+            id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
+            filename='node-gpt-response.csv.gz', log=True)
+        self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+        print('downloading llm predictions')
+        fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
 
     def process(self) -> None:
+        # process Title and Abstraction
         if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
             text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
             self.text = list(text_df['text'])
@@ -212,6 +254,42 @@ class TAGDataset(InMemoryDataset):
                 "The raw text of each node is not specified"
                 "Please pass in 'text' when convert your dataset "
                 "to Text Attribute Graph Dataset")
+        # process LLM explanation and prediction
+        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
+        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
+        if osp.exists(llm_explanation_path) and osp.exists(
+                llm_prediction_path):
+            # load LLM explanation
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            # load LLM prediction
+            preds = []
+            with open(llm_prediction_path) as file:
+                reader = csv.reader(file)
+                for row in reader:
+                    inner_list = []
+                    for value in row:
+                        inner_list.append(int(value))
+                    preds.append(inner_list)
+
+            pl = torch.zeros(len(preds), self.llm_prediction_topk,
+                             dtype=torch.long)
+            for i, pred in enumerate(preds):
+                pl[i][:len(pred)] = torch.tensor(
+                    pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+        elif self.name in self.llm_explanation_id:
+            self.download()
+        else:
+            print(
+                'The dataset is not ogbn-arxiv,'
+                'please pass in your llm explanation list to `llm_explanation`'
+                'and llm prediction list to `llm_prediction`')
+        if self.llm_explanation is None or pl is None:
+            raise ValueError(
+                "The TAGDataset only have ogbn-arxiv LLM explanations"
+                "and predictions in default. The llm explanation and"
+                "prediction of each node is not specified."
+                "Please pass in 'llm_explanation' and 'llm_prediction' when"
+                "convert your dataset to Text Attribute Graph Dataset")
 
     def save_node_text(self, text: List[str]) -> None:
         node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
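The new prediction-loading branch above pads each node's variable-length top-k label list into a fixed (num_nodes, llm_prediction_topk) tensor, shifting every label by +1 so that 0 stays reserved as the padding value. A toy illustration of the same padding scheme (made-up label lists, not dataset content):

    import torch

    preds = [[3, 7], [1], [4, 2, 9]]  # toy per-node top-k LLM label lists
    topk = 5
    pl = torch.zeros(len(preds), topk, dtype=torch.long)
    for i, pred in enumerate(preds):
        # +1 shift keeps 0 free to act as padding
        pl[i][:len(pred)] = torch.tensor(pred[:topk], dtype=torch.long) + 1
    # pl == tensor([[ 4,  8,  0,  0,  0],
    #               [ 2,  0,  0,  0,  0],
    #               [ 5,  3, 10,  0,  0]])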
@@ -224,22 +302,39 @@ class TAGDataset(InMemoryDataset):
         text_df.to_csv(osp.join(node_text_path), compression='gzip',
                        index=False)
 
-    def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
+    def tokenize_graph(self, batch_size: int = 256,
+                       text_type: str = 'raw_text') -> Dict[str, Tensor]:
         r"""Tokenizing the text associate with each node, running in cpu.
 
         Args:
             batch_size (Optional[int]): batch size of list of text for
                 generating emebdding
+            text_type (Optional[str]): type of text
         Returns:
             Dict[str, torch.Tensor]: tokenized graph
         """
+        assert text_type in ['raw_text', 'llm_explanation', 'all']
+        if text_type == 'raw_text':
+            _text = self.text
+        elif text_type == 'llm_explanation':
+            _text = self.llm_explanation
+        elif text_type == 'all':
+            if self.text is None or self.llm_explanation is None:
+                raise ValueError("The TAGDataset need text and llm explanation"
+                                 "for tokenizing all text")
+            _text = [
+                f'{raw_txt} Explanation: {exp_txt}'
+                for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
+            ]
+
         data_len = 0
-        if self.text is not None:
-            data_len = len(self.text)
+        if _text is not None:
+            data_len = len(_text)
         else:
             raise ValueError("The TAGDataset need text for tokenization")
         token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
-        path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
+        path = os.path.join(self.processed_dir, 'token', text_type,
+                            self.tokenizer_name)
         # Check if the .pt files already exist
         token_files_exist = any(
             os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
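For text_type='all', the hunk above joins raw text and LLM explanation into one string per node before tokenization, and caches tokens under a per-text_type subdirectory so the three variants do not overwrite each other. A toy illustration of the join (placeholder strings, not dataset content):

    text = ['Title: Example paper. Abstract: ...']
    llm_explanation = ['Likely cs.LG, because ...']
    _text = [
        f'{raw_txt} Explanation: {exp_txt}'
        for raw_txt, exp_txt in zip(text, llm_explanation)
    ]
    # -> ['Title: Example paper. Abstract: ... Explanation: Likely cs.LG, because ...']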
@@ -256,12 +351,12 @@ class TAGDataset(InMemoryDataset):
         all_encoded_token = {k: [] for k in token_keys}
         pbar = tqdm(total=data_len)
 
-        pbar.set_description('Tokenizing Text Attributed Graph')
+        pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
         for i in range(0, data_len, batch_size):
             end_index = min(data_len, i + batch_size)
-            token = self.tokenizer(self.text[i:min(i + batch_size, data_len)],
-                                   padding='max_length', truncation=True,
-                                   max_length=512, return_tensors="pt")
+            token = self.tokenizer(_text[i:end_index], padding='max_length',
+                                   truncation=True, max_length=512,
+                                   return_tensors="pt")
             for k in token.keys():
                 all_encoded_token[k].append(token[k])
             pbar.update(end_index - i)
@@ -289,10 +384,18 @@ class TAGDataset(InMemoryDataset):
 
         Args:
             tag_dataset (TAGDataset): the parent dataset
+            text_type (str): type of text
         """
-        def __init__(self, tag_dataset: 'TAGDataset') -> None:
+        def __init__(self, tag_dataset: 'TAGDataset',
+                     text_type: str = 'raw_text') -> None:
+            assert text_type in ['raw_text', 'llm_explanation', 'all']
             self.tag_dataset = tag_dataset
-            self.token = tag_dataset.token
+            if text_type == 'raw_text':
+                self.token = tag_dataset.token
+            elif text_type == 'llm_explanation':
+                self.token = tag_dataset.llm_explanation_token
+            elif text_type == 'all':
+                self.token = tag_dataset.all_token
             assert tag_dataset._data is not None
             self._data = tag_dataset._data
 
@@ -312,7 +415,8 @@ class TAGDataset(InMemoryDataset):
 
         # for LM training
         def __getitem__(
-            self, node_id: IndexType
+            self,
+            node_id: IndexType,
         ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
             r"""This function will override the function in
             torch.utils.data.Dataset, and will be called when you
@@ -343,8 +447,8 @@ class TAGDataset(InMemoryDataset):
         def __repr__(self) -> str:
             return f'{self.__class__.__name__}()'
 
-    def to_text_dataset(self) -> TextDataset:
+    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
         r"""Factory Build text dataset from Text Attributed Graph Dataset
         each data point is node's associated text token.
         """
-        return TAGDataset.TextDataset(self)
+        return TAGDataset.TextDataset(self, text_type)
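Continuing the constructor sketch above, each text variant can then be materialized as its own LM-ready dataset (hedged usage; names as introduced in this diff):

    raw_ds = tag_dataset.to_text_dataset()  # raw text tokens
    exp_ds = tag_dataset.to_text_dataset(text_type='llm_explanation')
    all_ds = tag_dataset.to_text_dataset(text_type='all')  # 'raw Explanation: exp'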
--- torch_geometric/loader/utils.py
+++ torch_geometric/loader/utils.py
@@ -256,14 +256,6 @@ def filter_custom_hetero_store(
     # Construct a new `HeteroData` object:
     data = custom_cls() if custom_cls is not None else HeteroData()
 
-    # Filter edge storage:
-    # TODO support edge attributes
-    for attr in graph_store.get_all_edge_attrs():
-        key = attr.edge_type
-        if key in row_dict and key in col_dict:
-            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
-            data[attr.edge_type].edge_index = edge_index
-
     # Filter node storage:
     required_attrs = []
     for attr in feature_store.get_all_tensor_attrs():
@@ -280,6 +272,14 @@ def filter_custom_hetero_store(
     for i, attr in enumerate(required_attrs):
         data[attr.group_name][attr.attr_name] = tensors[i]
 
+    # Filter edge storage:
+    # TODO support edge attributes
+    for attr in graph_store.get_all_edge_attrs():
+        key = attr.edge_type
+        if key in row_dict and key in col_dict:
+            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
+            data[attr.edge_type].edge_index = edge_index
+
     return data
 
 
--- torch_geometric/metrics/link_pred.py
+++ torch_geometric/metrics/link_pred.py
@@ -53,7 +53,7 @@ class LinkPredMetricData:
 
         # Flatten both prediction and ground-truth indices, and determine
         # overlaps afterwards via `torch.searchsorted`.
-        max_index = max(  # type: ignore
+        max_index = max(
             self.pred_index_mat.max()
             if self.pred_index_mat.numel() > 0 else 0,
             self.edge_label_index[1].max()
@@ -820,9 +820,10 @@ class LinkPredPersonalization(_LinkPredMetric):
         right = pred[col.cpu()].to(device)
 
         # Use offset to work around applying `isin` along a specific dim:
-        i = max(left.max(), right.max()) + 1  # type: ignore
-        i = torch.arange(0, i * row.size(0), i, device=device).view(-1, 1)
-        isin = torch.isin(left + i, right + i)
+        i = max(int(left.max()), int(right.max())) + 1
+        idx = torch.arange(0, i * row.size(0), i, device=device)
+        idx = idx.view(-1, 1)
+        isin = torch.isin(left + idx, right + idx)
 
         # Compute personalization via average inverse cosine similarity:
         cos = isin.sum(dim=-1) / pred.size(1)
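The offset trick in this last hunk turns a per-row membership test into a single flat torch.isin call: adding a distinct offset to every row shifts each row's values into a disjoint range, so values from different rows can never collide. A standalone sketch with small non-negative integer IDs (toy tensors, not the metric's real inputs):

    import torch

    left = torch.tensor([[0, 2], [1, 3]])
    right = torch.tensor([[2, 5], [0, 3]])

    # The offset must exceed every value so each row occupies its own range:
    i = max(int(left.max()), int(right.max())) + 1
    idx = torch.arange(0, i * left.size(0), i).view(-1, 1)
    isin = torch.isin(left + idx, right + idx)
    # Row 0: which of {0, 2} are in {2, 5}?  -> [False, True]
    # Row 1: which of {1, 3} are in {0, 3}?  -> [False, True]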