torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -82,57 +82,67 @@ class DataGenerator(object):
82
82
 
83
83
 
84
84
  def get_auto_embedding_dim(num_classes):
85
- """ Calculate the dim of embedding vector according to number of classes in the category
86
- emb_dim = [6 * (num_classes)^(1/4)]
87
- reference: Deep & Cross Network for Ad Click Predictions.(ADKDD'17)
88
- Args:
89
- num_classes: number of classes in the category
90
-
91
- Returns:
92
- the dim of embedding vector
85
+ """Calculate embedding dim by category size.
86
+
87
+ Uses ``emb_dim = floor(6 * num_classes**0.25)`` from DCN (ADKDD'17).
88
+
89
+ Parameters
90
+ ----------
91
+ num_classes : int
92
+ Number of categorical classes.
93
+
94
+ Returns
95
+ -------
96
+ int
97
+ Recommended embedding dimension.
93
98
  """
94
99
  return int(np.floor(6 * np.pow(num_classes, 0.25)))
95
100
 
96
101
 
97
102
  def get_loss_func(task_type="classification"):
103
+ """Return default loss by task type."""
98
104
  if task_type == "classification":
99
105
  return torch.nn.BCELoss()
100
- elif task_type == "regression":
106
+ if task_type == "regression":
101
107
  return torch.nn.MSELoss()
102
- else:
103
- raise ValueError("task_type must be classification or regression")
108
+ raise ValueError("task_type must be classification or regression")
104
109
 
105
110
 
106
111
  def get_metric_func(task_type="classification"):
112
+ """Return default metric by task type."""
107
113
  if task_type == "classification":
108
114
  return roc_auc_score
109
- elif task_type == "regression":
115
+ if task_type == "regression":
110
116
  return mean_squared_error
111
- else:
112
- raise ValueError("task_type must be classification or regression")
117
+ raise ValueError("task_type must be classification or regression")
113
118
 
114
119
 
115
120
  def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols=[], min_item=0, shuffle=True, max_len=50):
116
- """generate sequence feature and negative sample for ranking.
117
-
118
- Args:
119
- data (pd.DataFrame): the raw data.
120
- user_col (str): the col name of user_id
121
- item_col (str): the col name of item_id
122
- time_col (str): the col name of timestamp
123
- item_attribute_cols (list[str], optional): the other attribute cols of item which you want to generate sequence feature. Defaults to `[]`.
124
- sample_method (int, optional): the negative sample method `{
125
- 0: "random sampling",
126
- 1: "popularity sampling method used in word2vec",
127
- 2: "popularity sampling method by `log(count+1)+1e-6`",
128
- 3: "tencent RALM sampling"}`.
129
- Defaults to 0.
130
- min_item (int, optional): the min item each user must have. Defaults to 0.
131
- shuffle (bool, optional): shulle if True
132
- max_len (int, optional): the max length of a user history sequence.
133
-
134
- Returns:
135
- pd.DataFrame: split train, val and test data with sequence features by time.
121
+ """Generate sequence features and negatives for ranking.
122
+
123
+ Parameters
124
+ ----------
125
+ data : pd.DataFrame
126
+ Raw interaction data.
127
+ user_col : str
128
+ User id column name.
129
+ item_col : str
130
+ Item id column name.
131
+ time_col : str
132
+ Timestamp column name.
133
+ item_attribute_cols : list[str], optional
134
+ Additional item attribute columns to include in sequences.
135
+ min_item : int, default=0
136
+ Minimum items per user; users below are dropped.
137
+ shuffle : bool, default=True
138
+ Shuffle train/val/test.
139
+ max_len : int, default=50
140
+ Max history length.
141
+
142
+ Returns
143
+ -------
144
+ tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
145
+ Train, validation, and test data with sequence features.
136
146
  """
137
147
  for feat in data:
138
148
  le = LabelEncoder()
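
The docstring above documents the rule emb_dim = floor(6 * num_classes**0.25) from DCN (ADKDD'17). A minimal standalone sketch of that rule (plain NumPy, not the library function itself; np.power is used here for compatibility with older NumPy releases):

    import numpy as np

    def auto_embedding_dim(num_classes):
        # emb_dim = floor(6 * num_classes ** 0.25), as documented above
        return int(np.floor(6 * np.power(num_classes, 0.25)))

    for n in (10, 1000, 100000):
        print(n, auto_embedding_dim(n))  # 10 -> 10, 1000 -> 33, 100000 -> 106
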
@@ -205,12 +215,17 @@ def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols
205
215
 
206
216
 
207
217
  def df_to_dict(data):
208
- """
209
- Convert the DataFrame to a dict type input that the network can accept
210
- Args:
211
- data (pd.DataFrame): datasets of type DataFrame
212
- Returns:
213
- The converted dict, which can be used directly into the input network
218
+ """Convert DataFrame to dict inputs accepted by models.
219
+
220
+ Parameters
221
+ ----------
222
+ data : pd.DataFrame
223
+ Input dataframe.
224
+
225
+ Returns
226
+ -------
227
+ dict
228
+ Mapping of column name to numpy array.
214
229
  """
215
230
  data_dict = data.to_dict('list')
216
231
  for key in data.keys():
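
df_to_dict maps each column to a NumPy array keyed by column name. An illustrative plain-pandas equivalent (not the library code itself):

    import pandas as pd

    df = pd.DataFrame({"user_id": [1, 2, 3], "item_id": [10, 20, 30]})
    data_dict = {col: df[col].to_numpy() for col in df.columns}
    print(data_dict["user_id"])  # array([1, 2, 3])
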
@@ -226,20 +241,28 @@ def neg_sample(click_hist, item_size):
226
241
 
227
242
 
228
243
  def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
229
- """ Pads sequences (list of list) to the ndarray of same length.
230
- This is an equivalent implementation of tf.keras.preprocessing.sequence.pad_sequences
231
- reference: https://github.com/huawei-noah/benchmark/tree/main/FuxiCTR/fuxictr
232
-
233
- Args:
234
- sequences (pd.DataFrame): data that needs to pad or truncate
235
- maxlen (int): maximum sequence length. Defaults to None.
236
- dtype (str, optional): Defaults to 'int32'.
237
- padding (str, optional): if len(sequences) less than maxlen, padding style, {'pre', 'post'}. Defaults to 'pre'.
238
- truncating (str, optional): if len(sequences) more than maxlen, truncate style, {'pre', 'post'}. Defaults to 'pre'.
239
- value (_type_, optional): Defaults to 0..
240
-
241
- Returns:
242
- _type_: _description_
244
+ """Pad list-of-lists sequences to equal length.
245
+
246
+ Equivalent to ``tf.keras.preprocessing.sequence.pad_sequences``.
247
+
248
+ Parameters
249
+ ----------
250
+ sequences : Sequence[Sequence]
251
+ Input sequences.
252
+ maxlen : int, optional
253
+ Maximum length; computed if None.
254
+ dtype : str, default='int32'
255
+ padding : {'pre', 'post'}, default='pre'
256
+ Padding direction.
257
+ truncating : {'pre', 'post'}, default='pre'
258
+ Truncation direction.
259
+ value : float, default=0.0
260
+ Padding value.
261
+
262
+ Returns
263
+ -------
264
+ np.ndarray
265
+ Padded array of shape (n_samples, maxlen).
243
266
  """
244
267
 
245
268
  assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
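
The 'pre' padding and truncation behaviour documented for pad_sequences can be illustrated with a small hand-rolled NumPy sketch (not the library implementation):

    import numpy as np

    seqs = [[1, 2, 3], [4, 5], [6]]
    maxlen = 4
    out = np.full((len(seqs), maxlen), 0, dtype="int32")
    for i, s in enumerate(seqs):
        s = s[-maxlen:]                # truncating='pre' keeps the latest items
        out[i, maxlen - len(s):] = s   # padding='pre' left-pads with the fill value
    print(out)
    # [[0 1 2 3]
    #  [0 0 4 5]
    #  [0 0 0 6]]
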
@@ -265,13 +288,19 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncati
265
288
 
266
289
 
267
290
  def array_replace_with_dict(array, dic):
268
- """Replace values in NumPy array based on dictionary.
269
- Args:
270
- array (np.array): a numpy array
271
- dic (dict): a map dict
272
-
273
- Returns:
274
- np.array: array with replace
291
+ """Replace values in numpy array using a mapping dict.
292
+
293
+ Parameters
294
+ ----------
295
+ array : np.ndarray
296
+ Input array.
297
+ dic : dict
298
+ Mapping from old to new values.
299
+
300
+ Returns
301
+ -------
302
+ np.ndarray
303
+ Array with values replaced.
275
304
  """
276
305
  # Extract out keys and values
277
306
  k = np.array(list(dic.keys()))
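
A sketch of the dictionary-based replacement that array_replace_with_dict documents, using sorted keys plus np.searchsorted (illustrative only; this sketch assumes every array value appears as a key of the dict):

    import numpy as np

    dic = {1: 100, 2: 200, 3: 300}
    arr = np.array([[1, 2], [3, 1]])
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))
    order = k.argsort()
    out = v[order][np.searchsorted(k[order], arr)]
    print(out)  # [[100 200]
                #  [300 100]]
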
@@ -284,19 +313,25 @@ def array_replace_with_dict(array, dic):
284
313
 
285
314
  # Temporarily reserved for testing purposes(1985312383@qq.com)
286
315
  def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
287
- """Build a sequence of user's history by time.
288
-
289
- Args:
290
- data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
291
- seq_feature_col (list): specify the column name that needs to generate sequence features, and its sequence features will be generated according to userid.
292
- max_len (int): the max length of a user history sequence.
293
- drop_short (int): remove some inactive user who's sequence length < drop_short.
294
- shuffle (bool): shuffle data if true.
295
-
296
- Returns:
297
- train (pd.DataFrame): target item will be each item before last two items.
298
- val (pd.DataFrame): target item is the second to last item of user's history sequence.
299
- test (pd.DataFrame): target item is the last item of user's history sequence.
316
+ """Build user history sequences by time.
317
+
318
+ Parameters
319
+ ----------
320
+ data : pd.DataFrame
321
+ Must contain ``user_id, item_id, cate_id, time``.
322
+ seq_feature_col : list, default ['item_id', 'cate_id']
323
+ Columns to generate sequence features.
324
+ max_len : int, default=50
325
+ Max history length.
326
+ drop_short : int, default=3
327
+ Drop users with sequence length < drop_short.
328
+ shuffle : bool, default=True
329
+ Shuffle outputs.
330
+
331
+ Returns
332
+ -------
333
+ tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
334
+ Train/val/test splits with sequence features.
300
335
  """
301
336
  for feat in data:
302
337
  le = LabelEncoder()
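
The split semantics described in the original docstring above (train targets are the items before the last two, validation uses the second-to-last item, test uses the last item) amount to a leave-one-out style split by time. A rough sketch for a single user's history:

    # items of one user, ordered by time
    history = [11, 22, 33, 44, 55]
    train_targets = history[1:-2]   # [22, 33] -> each item before the last two
    val_target = history[-2]        # 44       -> second to last
    test_target = history[-1]       # 55       -> last
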
@@ -357,30 +392,32 @@ def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50
357
392
 
358
393
 
359
394
  class SeqDataset(Dataset):
360
- """Sequence dataset for HSTU-style generative models.
361
-
362
- This class wraps precomputed sequence features for next-item prediction
363
- tasks, including tokens, positions, time differences and targets.
364
-
365
- Args:
366
- seq_tokens (np.ndarray): Token ids of shape ``(num_samples, seq_len)``.
367
- seq_positions (np.ndarray): Position indices of shape
368
- ``(num_samples, seq_len)``.
369
- targets (np.ndarray): Target token ids of shape ``(num_samples,)``.
370
- seq_time_diffs (np.ndarray): Time-difference features of shape
371
- ``(num_samples, seq_len)``.
372
-
373
- Shape:
374
- - Output: A tuple ``(seq_tokens, seq_positions, seq_time_diffs, target)``.
375
-
376
- Example:
377
- >>> seq_tokens = np.random.randint(0, 1000, (100, 256))
378
- >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(100, axis=0)
379
- >>> seq_time_diffs = np.random.randint(0, 86400, (100, 256))
380
- >>> targets = np.random.randint(0, 1000, (100,))
381
- >>> dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
382
- >>> len(dataset)
383
- 100
395
+ """Sequence dataset for HSTU-style next-item prediction.
396
+
397
+ Parameters
398
+ ----------
399
+ seq_tokens : np.ndarray
400
+ Token ids, shape ``(num_samples, seq_len)``.
401
+ seq_positions : np.ndarray
402
+ Position indices, shape ``(num_samples, seq_len)``.
403
+ targets : np.ndarray
404
+ Target token ids, shape ``(num_samples,)``.
405
+ seq_time_diffs : np.ndarray
406
+ Time-difference features, shape ``(num_samples, seq_len)``.
407
+
408
+ Shape
409
+ -----
410
+ Output tuple: ``(seq_tokens, seq_positions, seq_time_diffs, target)``
411
+
412
+ Examples
413
+ --------
414
+ >>> seq_tokens = np.random.randint(0, 1000, (100, 256))
415
+ >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(100, axis=0)
416
+ >>> seq_time_diffs = np.random.randint(0, 86400, (100, 256))
417
+ >>> targets = np.random.randint(0, 1000, (100,))
418
+ >>> dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
419
+ >>> len(dataset)
420
+ 100
384
421
  """
385
422
 
386
423
  def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -414,29 +451,25 @@ class SeqDataset(Dataset):
414
451
 
415
452
 
416
453
  class SequenceDataGenerator(object):
417
- """Sequence data generator used for HSTU-style models.
418
-
419
- This helper wraps a :class:`SeqDataset` and provides convenient utilities
420
- to construct train/val/test ``DataLoader`` objects.
421
-
422
- Args:
423
- seq_tokens (np.ndarray): Token ids of shape ``(num_samples, seq_len)``.
424
- seq_positions (np.ndarray): Position indices of shape
425
- ``(num_samples, seq_len)``.
426
- targets (np.ndarray): Target token ids of shape ``(num_samples,)``.
427
- seq_time_diffs (np.ndarray): Time-difference features of shape
428
- ``(num_samples, seq_len)``.
429
-
430
- Methods:
431
- generate_dataloader: Build train/val/test data loaders.
432
-
433
- Example:
434
- >>> seq_tokens = np.random.randint(0, 1000, (1000, 256))
435
- >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(1000, axis=0)
436
- >>> seq_time_diffs = np.random.randint(0, 86400, (1000, 256))
437
- >>> targets = np.random.randint(0, 1000, (1000,))
438
- >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
439
- >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
454
+ """Sequence data generator for HSTU-style models.
455
+
456
+ Wraps :class:`SeqDataset` and builds train/val/test loaders.
457
+
458
+ Parameters
459
+ ----------
460
+ seq_tokens : np.ndarray
461
+ Token ids, shape ``(num_samples, seq_len)``.
462
+ seq_positions : np.ndarray
463
+ Position indices, shape ``(num_samples, seq_len)``.
464
+ targets : np.ndarray
465
+ Target token ids, shape ``(num_samples,)``.
466
+ seq_time_diffs : np.ndarray
467
+ Time-difference features, shape ``(num_samples, seq_len)``.
468
+
469
+ Examples
470
+ --------
471
+ >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
472
+ >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
440
473
  """
441
474
 
442
475
  def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -449,44 +482,57 @@ class SequenceDataGenerator(object):
449
482
  # Underlying dataset
450
483
  self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
451
484
 
452
- def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
453
- """生成数据加载器.
454
-
455
- Args:
456
- batch_size (int): batch size, default 32
457
- num_workers (int): number of data loading workers, default 0
458
- split_ratio (tuple): split ratio (train, val, test), default (0.7, 0.1, 0.2)
459
-
460
- Returns:
461
- tuple: (train_loader, val_loader, test_loader)
462
-
463
- Example:
464
- >>> train_loader, val_loader, test_loader = gen.generate_dataloader(
465
- ... batch_size=32,
466
- ... num_workers=4,
467
- ... split_ratio=(0.7, 0.1, 0.2)
468
- ... )
485
+ def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
486
+ """Generate dataloader(s) from the dataset.
487
+
488
+ Parameters
489
+ ----------
490
+ batch_size : int, default=32
491
+ Batch size for DataLoader.
492
+ num_workers : int, default=0
493
+ Number of workers for DataLoader.
494
+ split_ratio : tuple or None, default=None
495
+ If None, returns a single DataLoader without splitting the data.
496
+ If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
497
+ (train_loader, val_loader, test_loader).
498
+ shuffle : bool, default=True
499
+ Whether to shuffle data. Only applies when split_ratio is None.
500
+ When split_ratio is provided, train data is always shuffled.
501
+
502
+ Returns
503
+ -------
504
+ tuple
505
+ If split_ratio is None: returns (dataloader,)
506
+ If split_ratio is provided: returns (train_loader, val_loader, test_loader)
507
+
508
+ Examples
509
+ --------
510
+ # Case 1: Data already split, just create loader
511
+ >>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
512
+ >>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
513
+
514
+ # Case 2: Auto-split data into train/val/test
515
+ >>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
516
+ >>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
517
+ ... batch_size=32, split_ratio=(0.7, 0.1, 0.2))
469
518
  """
470
519
  if split_ratio is None:
471
- split_ratio = (0.7, 0.1, 0.2)
520
+ # No split - data is already divided, just create a single DataLoader
521
+ dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
522
+ return (dataloader,)
472
523
 
473
- # Validate the split ratio
524
+ # Split data into train/val/test
474
525
  assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"
475
526
 
476
- # Compute the split sizes
477
527
  total_size = len(self.dataset)
478
528
  train_size = int(total_size * split_ratio[0])
479
529
  val_size = int(total_size * split_ratio[1])
480
530
  test_size = total_size - train_size - val_size
481
531
 
482
- # Split the dataset
483
532
  train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])
484
533
 
485
- # Create the data loaders
486
534
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
487
-
488
535
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
489
-
490
536
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
491
537
 
492
538
  return train_loader, val_loader, test_loader
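
A usage sketch for the updated generate_dataloader contract (the import path is assumed from this package layout, and the arrays are random placeholders):

    import numpy as np
    from torch_rechub.utils.data import SequenceDataGenerator  # assumed module path

    seq_tokens = np.random.randint(0, 1000, (100, 16))
    seq_positions = np.arange(16)[np.newaxis, :].repeat(100, axis=0)
    seq_time_diffs = np.random.randint(0, 86400, (100, 16))
    targets = np.random.randint(0, 1000, (100,))

    gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
    (loader,) = gen.generate_dataloader(batch_size=32)    # split_ratio=None -> 1-tuple
    train_dl, val_dl, test_dl = gen.generate_dataloader(
        batch_size=32, split_ratio=(0.7, 0.1, 0.2))       # three loaders
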
@@ -6,25 +6,27 @@ import torch.nn as nn
6
6
 
7
7
 
8
8
  class RelPosBias(nn.Module):
9
- """Relative position bias module.
10
-
11
- This module is used in HSTU self-attention layers to provide a learnable
12
- bias that depends on the relative distance between sequence positions. It
13
- can be combined with time-based bucketing when needed.
14
-
15
- Args:
16
- n_heads (int): Number of attention heads.
17
- max_seq_len (int): Maximum supported sequence length.
18
- num_buckets (int): Number of relative position buckets. Default: 32.
19
-
20
- Shape:
21
- - Output: ``(1, n_heads, seq_len, seq_len)``
22
-
23
- Example:
24
- >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
25
- >>> bias = rel_pos_bias(256)
26
- >>> bias.shape
27
- torch.Size([1, 8, 256, 256])
9
+ """Relative position bias for attention.
10
+
11
+ Parameters
12
+ ----------
13
+ n_heads : int
14
+ Number of attention heads.
15
+ max_seq_len : int
16
+ Maximum supported sequence length.
17
+ num_buckets : int, default=32
18
+ Number of relative position buckets.
19
+
20
+ Shape
21
+ -----
22
+ Output: ``(1, n_heads, seq_len, seq_len)``
23
+
24
+ Examples
25
+ --------
26
+ >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
27
+ >>> bias = rel_pos_bias(256)
28
+ >>> bias.shape
29
+ torch.Size([1, 8, 256, 256])
28
30
  """
29
31
 
30
32
  def __init__(self, n_heads, max_seq_len, num_buckets=32):
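
For context, a bias of shape (1, n_heads, seq_len, seq_len) such as the one RelPosBias returns is typically added to the attention scores before softmax. A generic PyTorch sketch (not the RelPosBias implementation itself):

    import torch

    n_heads, seq_len, d = 8, 256, 64
    q = torch.randn(2, n_heads, seq_len, d)
    k = torch.randn(2, n_heads, seq_len, d)
    bias = torch.zeros(1, n_heads, seq_len, seq_len)   # stand-in for rel_pos_bias(seq_len)
    scores = q @ k.transpose(-2, -1) / d ** 0.5 + bias
    attn = scores.softmax(dim=-1)
    print(attn.shape)  # torch.Size([2, 8, 256, 256])
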
@@ -87,22 +89,20 @@ class RelPosBias(nn.Module):
87
89
 
88
90
 
89
91
  class VocabMask(nn.Module):
90
- """Vocabulary mask used to constrain generation during inference.
91
-
92
- At inference time this module can be used to mask out invalid item IDs
93
- so that the model never generates them.
94
-
95
- Args:
96
- vocab_size (int): Vocabulary size.
97
- invalid_items (list, optional): List of invalid item IDs to be masked.
98
-
99
- Methods:
100
- apply_mask: Apply the mask to logits.
101
-
102
- Example:
103
- >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
104
- >>> logits = torch.randn(32, 1000)
105
- >>> masked_logits = mask.apply_mask(logits)
92
+ """Vocabulary mask to block invalid items at inference.
93
+
94
+ Parameters
95
+ ----------
96
+ vocab_size : int
97
+ Vocabulary size.
98
+ invalid_items : list, optional
99
+ IDs to mask out.
100
+
101
+ Examples
102
+ --------
103
+ >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
104
+ >>> logits = torch.randn(32, 1000)
105
+ >>> masked_logits = mask.apply_mask(logits)
106
106
  """
107
107
 
108
108
  def __init__(self, vocab_size, invalid_items=None):
@@ -123,13 +123,17 @@ class VocabMask(nn.Module):
123
123
  self.mask[item_id] = False
124
124
 
125
125
  def apply_mask(self, logits):
126
- """应用掩码到logits.
127
-
128
- Args:
129
- logits (Tensor): model output logits, shape: (..., vocab_size)
130
-
131
- Returns:
132
- Tensor: masked logits
126
+ """Apply mask to logits.
127
+
128
+ Parameters
129
+ ----------
130
+ logits : Tensor
131
+ Model logits, shape ``(..., vocab_size)``.
132
+
133
+ Returns
134
+ -------
135
+ Tensor
136
+ Masked logits.
133
137
  """
134
138
  # Set the logits of invalid items to a very small value
135
139
  masked_logits = logits.clone()
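
apply_mask fills invalid positions with a very small value so that they receive near-zero probability after softmax. An illustrative sketch with generic PyTorch ops (not the class above):

    import torch

    logits = torch.randn(2, 10)
    invalid = torch.tensor([0, 1, 2])
    masked = logits.clone()
    masked[:, invalid] = torch.finfo(logits.dtype).min
    probs = masked.softmax(dim=-1)
    print(probs[:, invalid].max())  # ~0.0
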
@@ -139,26 +143,25 @@ class VocabMask(nn.Module):
139
143
 
140
144
 
141
145
  class VocabMapper(object):
142
- """Simple mapper between ``item_id`` and ``token_id``.
143
-
144
- In sequence generation tasks we often treat item IDs as tokens. This
145
- helper keeps a trivial identity mapping but makes the intent explicit and
146
- allows future extensions (e.g., reserved IDs, remapping, etc.).
147
-
148
- Args:
149
- vocab_size (int): Size of the vocabulary.
150
- pad_id (int): ID used for the PAD token. Default: 0.
151
- unk_id (int): ID used for unknown tokens. Default: 1.
152
-
153
- Methods:
154
- encode: Map ``item_id`` to ``token_id``.
155
- decode: Map ``token_id`` back to ``item_id``.
156
-
157
- Example:
158
- >>> mapper = VocabMapper(vocab_size=1000)
159
- >>> item_ids = np.array([10, 20, 30])
160
- >>> token_ids = mapper.encode(item_ids)
161
- >>> decoded_ids = mapper.decode(token_ids)
146
+ """Identity mapper between ``item_id`` and ``token_id``.
147
+
148
+ Useful for sequence generation where items are treated as tokens.
149
+
150
+ Parameters
151
+ ----------
152
+ vocab_size : int
153
+ Vocabulary size.
154
+ pad_id : int, default=0
155
+ PAD token id.
156
+ unk_id : int, default=1
157
+ Unknown token id.
158
+
159
+ Examples
160
+ --------
161
+ >>> mapper = VocabMapper(vocab_size=1000)
162
+ >>> item_ids = np.array([10, 20, 30])
163
+ >>> token_ids = mapper.encode(item_ids)
164
+ >>> decoded_ids = mapper.decode(token_ids)
162
165
  """
163
166
 
164
167
  def __init__(self, vocab_size, pad_id=0, unk_id=1):
@@ -172,26 +175,34 @@ class VocabMapper(object):
172
175
  self.token2item = np.arange(vocab_size)
173
176
 
174
177
  def encode(self, item_ids):
175
- """将item_id转换为token_id.
176
-
177
- Args:
178
- item_ids (np.ndarray): array of item IDs
179
-
180
- Returns:
181
- np.ndarray: array of token IDs
178
+ """Convert item_ids to token_ids.
179
+
180
+ Parameters
181
+ ----------
182
+ item_ids : np.ndarray
183
+ Item ids.
184
+
185
+ Returns
186
+ -------
187
+ np.ndarray
188
+ Token ids.
182
189
  """
183
190
  # Handle out-of-range item_id values
184
191
  token_ids = np.where((item_ids >= 0) & (item_ids < self.vocab_size), item_ids, self.unk_id)
185
192
  return token_ids
186
193
 
187
194
  def decode(self, token_ids):
188
- """将token_id转换为item_id.
189
-
190
- Args:
191
- token_ids (np.ndarray): array of token IDs
192
-
193
- Returns:
194
- np.ndarray: array of item IDs
195
+ """Convert token_ids back to item_ids.
196
+
197
+ Parameters
198
+ ----------
199
+ token_ids : np.ndarray
200
+ Token ids.
201
+
202
+ Returns
203
+ -------
204
+ np.ndarray
205
+ Item ids.
195
206
  """
196
207
  # Handle out-of-range token_id values
197
208
  item_ids = np.where((token_ids >= 0) & (token_ids < self.vocab_size), token_ids, self.unk_id)
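
The out-of-range guard shown in encode/decode can be checked standalone with plain NumPy (it mirrors the np.where call above):

    import numpy as np

    vocab_size, unk_id = 1000, 1
    item_ids = np.array([10, 20, 5000])   # 5000 is out of range
    token_ids = np.where((item_ids >= 0) & (item_ids < vocab_size), item_ids, unk_id)
    print(token_ids)  # [10 20  1]
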