torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +228 -159
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/data/dataset.py +18 -31
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +12 -2
- torch_rechub/trainers/match_trainer.py +13 -2
- torch_rechub/trainers/mtl_trainer.py +12 -2
- torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +191 -145
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +34 -18
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +24 -18
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/utils/data.py
CHANGED
@@ -82,57 +82,67 @@ class DataGenerator(object):


 def get_auto_embedding_dim(num_classes):
-    """
-
-
-
-
-
-
-
+    """Calculate embedding dim by category size.
+
+    Uses ``emb_dim = floor(6 * num_classes**0.25)`` from DCN (ADKDD'17).
+
+    Parameters
+    ----------
+    num_classes : int
+        Number of categorical classes.
+
+    Returns
+    -------
+    int
+        Recommended embedding dimension.
     """
     return int(np.floor(6 * np.pow(num_classes, 0.25)))


 def get_loss_func(task_type="classification"):
+    """Return default loss by task type."""
     if task_type == "classification":
         return torch.nn.BCELoss()
-
+    if task_type == "regression":
         return torch.nn.MSELoss()
-
-        raise ValueError("task_type must be classification or regression")
+    raise ValueError("task_type must be classification or regression")


 def get_metric_func(task_type="classification"):
+    """Return default metric by task type."""
     if task_type == "classification":
         return roc_auc_score
-
+    if task_type == "regression":
         return mean_squared_error
-
-        raise ValueError("task_type must be classification or regression")
+    raise ValueError("task_type must be classification or regression")


 def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols=[], min_item=0, shuffle=True, max_len=50):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Generate sequence features and negatives for ranking.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Raw interaction data.
+    user_col : str
+        User id column name.
+    item_col : str
+        Item id column name.
+    time_col : str
+        Timestamp column name.
+    item_attribute_cols : list[str], optional
+        Additional item attribute columns to include in sequences.
+    min_item : int, default=0
+        Minimum items per user; users below are dropped.
+    shuffle : bool, default=True
+        Shuffle train/val/test.
+    max_len : int, default=50
+        Max history length.
+
+    Returns
+    -------
+    tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
+        Train, validation, and test data with sequence features.
     """
     for feat in data:
         le = LabelEncoder()
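
The new ``get_auto_embedding_dim`` docstring only documents the existing heuristic. A quick standalone sanity check of the formula (a sketch, not the package code; it uses ``np.power`` instead of the ``np.pow`` alias the package source calls):

    import numpy as np

    def auto_embedding_dim(num_classes):
        # Same rule as documented above: floor(6 * num_classes ** 0.25)
        return int(np.floor(6 * np.power(num_classes, 0.25)))

    print(auto_embedding_dim(10))      # 6 * 10**0.25    ~= 10.67 -> 10
    print(auto_embedding_dim(10000))   # 6 * 10000**0.25  = 60    -> 60
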
@@ -205,12 +215,17 @@ def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols


 def df_to_dict(data):
-    """
-
-
-
-
-
+    """Convert DataFrame to dict inputs accepted by models.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Input dataframe.
+
+    Returns
+    -------
+    dict
+        Mapping of column name to numpy array.
     """
     data_dict = data.to_dict('list')
     for key in data.keys():
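
The documented contract for ``df_to_dict`` is a column-name to numpy-array mapping; a minimal illustration of that format (not the package implementation):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"user_id": [1, 2, 3], "item_id": [10, 20, 30], "label": [1, 0, 1]})

    # One numpy array per column: the dict-of-arrays input format the models accept.
    data_dict = {col: np.array(vals) for col, vals in df.to_dict("list").items()}
    print(data_dict["item_id"])  # [10 20 30]
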
@@ -226,20 +241,28 @@ def neg_sample(click_hist, item_size):


 def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Pad list-of-lists sequences to equal length.
+
+    Equivalent to ``tf.keras.preprocessing.sequence.pad_sequences``.
+
+    Parameters
+    ----------
+    sequences : Sequence[Sequence]
+        Input sequences.
+    maxlen : int, optional
+        Maximum length; computed if None.
+    dtype : str, default='int32'
+    padding : {'pre', 'post'}, default='pre'
+        Padding direction.
+    truncating : {'pre', 'post'}, default='pre'
+        Truncation direction.
+    value : float, default=0.0
+        Padding value.
+
+    Returns
+    -------
+    np.ndarray
+        Padded array of shape (n_samples, maxlen).
     """

     assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
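
Since the docstring pins ``pad_sequences`` to the ``tf.keras`` semantics, a short usage sketch of the documented defaults (``padding='pre'``, ``truncating='pre'``); the expected output below simply follows those semantics:

    from torch_rechub.utils.data import pad_sequences

    seqs = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]
    padded = pad_sequences(seqs, maxlen=4, value=0)
    print(padded)
    # With 'pre' padding and 'pre' truncating this should give:
    # [[ 0  1  2  3]
    #  [ 0  0  4  5]
    #  [ 7  8  9 10]]
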
@@ -265,13 +288,19 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncati


 def array_replace_with_dict(array, dic):
-    """Replace values in
-
-
-
-
-
-
+    """Replace values in numpy array using a mapping dict.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        Input array.
+    dic : dict
+        Mapping from old to new values.
+
+    Returns
+    -------
+    np.ndarray
+        Array with values replaced.
     """
     # Extract out keys and values
     k = np.array(list(dic.keys()))
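
A usage sketch for ``array_replace_with_dict`` based on the new docstring; it assumes every value in the array has a key in ``dic``, since the behaviour for unmapped values is not documented here:

    import numpy as np
    from torch_rechub.utils.data import array_replace_with_dict

    arr = np.array([[1, 2], [3, 1]])
    mapping = {1: 100, 2: 200, 3: 300}
    print(array_replace_with_dict(arr, mapping))
    # [[100 200]
    #  [300 100]]
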
@@ -284,19 +313,25 @@ def array_replace_with_dict(array, dic):

 # Temporarily reserved for testing purposes(1985312383@qq.com)
 def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
-    """Build
-
-
-
-
-
-
-
-
-
-
-
-
+    """Build user history sequences by time.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Must contain ``user_id, item_id, cate_id, time``.
+    seq_feature_col : list, default ['item_id', 'cate_id']
+        Columns to generate sequence features.
+    max_len : int, default=50
+        Max history length.
+    drop_short : int, default=3
+        Drop users with sequence length < drop_short.
+    shuffle : bool, default=True
+        Shuffle outputs.
+
+    Returns
+    -------
+    tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
+        Train/val/test splits with sequence features.
     """
     for feat in data:
         le = LabelEncoder()
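
How the documented ``create_seq_features`` interface is meant to be called, sketched with a made-up interaction log that has the required ``user_id, item_id, cate_id, time`` columns (illustrative only; a real log needs enough rows per user to survive ``drop_short`` and the internal split):

    import pandas as pd
    from torch_rechub.utils.data import create_seq_features

    log = pd.DataFrame({
        "user_id": [1, 1, 1, 1, 2, 2, 2, 2],
        "item_id": [10, 11, 12, 13, 20, 21, 22, 23],
        "cate_id": [1, 1, 2, 2, 3, 3, 1, 1],
        "time":    [1, 2, 3, 4, 1, 2, 3, 4],
    })
    train, val, test = create_seq_features(log, max_len=50, drop_short=3)
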
@@ -357,30 +392,32 @@ def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50


 class SeqDataset(Dataset):
-    """Sequence dataset for HSTU-style
-
-
-
-
-
-
-
-
-
-
-
-
-    Shape
-
-
-
-
-
-
-
-
-
-
+    """Sequence dataset for HSTU-style next-item prediction.
+
+    Parameters
+    ----------
+    seq_tokens : np.ndarray
+        Token ids, shape ``(num_samples, seq_len)``.
+    seq_positions : np.ndarray
+        Position indices, shape ``(num_samples, seq_len)``.
+    targets : np.ndarray
+        Target token ids, shape ``(num_samples,)``.
+    seq_time_diffs : np.ndarray
+        Time-difference features, shape ``(num_samples, seq_len)``.
+
+    Shape
+    -----
+    Output tuple: ``(seq_tokens, seq_positions, seq_time_diffs, target)``
+
+    Examples
+    --------
+    >>> seq_tokens = np.random.randint(0, 1000, (100, 256))
+    >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(100, axis=0)
+    >>> seq_time_diffs = np.random.randint(0, 86400, (100, 256))
+    >>> targets = np.random.randint(0, 1000, (100,))
+    >>> dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
+    >>> len(dataset)
+    100
     """

     def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -414,29 +451,25 @@ class SeqDataset(Dataset):


 class SequenceDataGenerator(object):
-    """Sequence data generator
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    >>> seq_time_diffs = np.random.randint(0, 86400, (1000, 256))
-    >>> targets = np.random.randint(0, 1000, (1000,))
-    >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
-    >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
+    """Sequence data generator for HSTU-style models.
+
+    Wraps :class:`SeqDataset` and builds train/val/test loaders.
+
+    Parameters
+    ----------
+    seq_tokens : np.ndarray
+        Token ids, shape ``(num_samples, seq_len)``.
+    seq_positions : np.ndarray
+        Position indices, shape ``(num_samples, seq_len)``.
+    targets : np.ndarray
+        Target token ids, shape ``(num_samples,)``.
+    seq_time_diffs : np.ndarray
+        Time-difference features, shape ``(num_samples, seq_len)``.
+
+    Examples
+    --------
+    >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
+    >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
     """

     def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -449,44 +482,57 @@ class SequenceDataGenerator(object):
         # Underlying dataset
         self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)

-    def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
-        """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
+        """Generate dataloader(s) from the dataset.
+
+        Parameters
+        ----------
+        batch_size : int, default=32
+            Batch size for DataLoader.
+        num_workers : int, default=0
+            Number of workers for DataLoader.
+        split_ratio : tuple or None, default=None
+            If None, returns a single DataLoader without splitting the data.
+            If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
+            (train_loader, val_loader, test_loader).
+        shuffle : bool, default=True
+            Whether to shuffle data. Only applies when split_ratio is None.
+            When split_ratio is provided, train data is always shuffled.
+
+        Returns
+        -------
+        tuple
+            If split_ratio is None: returns (dataloader,)
+            If split_ratio is provided: returns (train_loader, val_loader, test_loader)
+
+        Examples
+        --------
+        # Case 1: Data already split, just create loader
+        >>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
+        >>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
+
+        # Case 2: Auto-split data into train/val/test
+        >>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
+        >>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
+        ...     batch_size=32, split_ratio=(0.7, 0.1, 0.2))
         """
         if split_ratio is None:
-
+            # No split - data is already divided, just create a single DataLoader
+            dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            return (dataloader,)

-        #
+        # Split data into train/val/test
         assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"

-        # Compute split sizes
         total_size = len(self.dataset)
         train_size = int(total_size * split_ratio[0])
         val_size = int(total_size * split_ratio[1])
         test_size = total_size - train_size - val_size

-        # Split the dataset
         train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])

-        # Create the dataloaders
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-
         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
         test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

         return train_loader, val_loader, test_loader
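
Behaviour change worth noting: ``generate_dataloader`` gains a ``shuffle`` flag, and the ``split_ratio=None`` path now returns a one-element tuple, whereas the old docstring example showed three loaders from a plain ``batch_size`` call. A usage sketch of the documented 0.2.0 behaviour, with made-up arrays in the documented shapes:

    import numpy as np
    from torch_rechub.utils.data import SequenceDataGenerator

    seq_tokens = np.random.randint(0, 1000, (1000, 256))
    seq_positions = np.tile(np.arange(256), (1000, 1))
    seq_time_diffs = np.random.randint(0, 86400, (1000, 256))
    targets = np.random.randint(0, 1000, (1000,))

    gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)

    # No split: data was divided beforehand; note the one-element tuple.
    (loader,) = gen.generate_dataloader(batch_size=32, shuffle=False)

    # Auto split: three loaders, train always shuffled.
    train_loader, val_loader, test_loader = gen.generate_dataloader(
        batch_size=32, split_ratio=(0.7, 0.1, 0.2))
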
torch_rechub/utils/hstu_utils.py
CHANGED
@@ -6,25 +6,27 @@ import torch.nn as nn


 class RelPosBias(nn.Module):
-    """Relative position bias
-
-
-
-
-
-
-
-
-
-
-    Shape
-
-
-
-
-
-
-
+    """Relative position bias for attention.
+
+    Parameters
+    ----------
+    n_heads : int
+        Number of attention heads.
+    max_seq_len : int
+        Maximum supported sequence length.
+    num_buckets : int, default=32
+        Number of relative position buckets.
+
+    Shape
+    -----
+    Output: ``(1, n_heads, seq_len, seq_len)``
+
+    Examples
+    --------
+    >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
+    >>> bias = rel_pos_bias(256)
+    >>> bias.shape
+    torch.Size([1, 8, 256, 256])
     """

     def __init__(self, n_heads, max_seq_len, num_buckets=32):
@@ -87,22 +89,20 @@ class RelPosBias(nn.Module):


 class VocabMask(nn.Module):
-    """Vocabulary mask
-
-
-
-
-
-
-
-
-
-
-
-
-
-    >>> logits = torch.randn(32, 1000)
-    >>> masked_logits = mask.apply_mask(logits)
+    """Vocabulary mask to block invalid items at inference.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size.
+    invalid_items : list, optional
+        IDs to mask out.
+
+    Examples
+    --------
+    >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
+    >>> logits = torch.randn(32, 1000)
+    >>> masked_logits = mask.apply_mask(logits)
     """

     def __init__(self, vocab_size, invalid_items=None):
@@ -123,13 +123,17 @@ class VocabMask(nn.Module):
             self.mask[item_id] = False

     def apply_mask(self, logits):
-        """
-
-
-
-
-
-
+        """Apply mask to logits.
+
+        Parameters
+        ----------
+        logits : Tensor
+            Model logits, shape ``(..., vocab_size)``.
+
+        Returns
+        -------
+        Tensor
+            Masked logits.
         """
         # Set the logits of invalid items to a very small value
         masked_logits = logits.clone()
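
``apply_mask`` itself is only documented here, not changed; a minimal re-creation of the documented behaviour (a sketch, not the package code) shows the idea of pushing blocked vocabulary entries low enough that softmax or top-k never selects them:

    import torch

    def apply_vocab_mask(logits, valid_mask):
        # Clone, then send blocked vocabulary ids to the lowest representable value.
        masked = logits.clone()
        masked[..., ~valid_mask] = torch.finfo(logits.dtype).min
        return masked

    logits = torch.randn(2, 10)
    valid = torch.ones(10, dtype=torch.bool)
    valid[[0, 1]] = False  # e.g. block PAD / UNK ids
    print(apply_vocab_mask(logits, valid).topk(3, dim=-1).indices)  # never contains 0 or 1
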
@@ -139,26 +143,25 @@ class VocabMask(nn.Module):


 class VocabMapper(object):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    >>> decoded_ids = mapper.decode(token_ids)
+    """Identity mapper between ``item_id`` and ``token_id``.
+
+    Useful for sequence generation where items are treated as tokens.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size.
+    pad_id : int, default=0
+        PAD token id.
+    unk_id : int, default=1
+        Unknown token id.
+
+    Examples
+    --------
+    >>> mapper = VocabMapper(vocab_size=1000)
+    >>> item_ids = np.array([10, 20, 30])
+    >>> token_ids = mapper.encode(item_ids)
+    >>> decoded_ids = mapper.decode(token_ids)
     """

     def __init__(self, vocab_size, pad_id=0, unk_id=1):
@@ -172,26 +175,34 @@ class VocabMapper(object):
         self.token2item = np.arange(vocab_size)

     def encode(self, item_ids):
-        """
-
-
-
-
-
-
+        """Convert item_ids to token_ids.
+
+        Parameters
+        ----------
+        item_ids : np.ndarray
+            Item ids.
+
+        Returns
+        -------
+        np.ndarray
+            Token ids.
         """
         # Handle item_ids outside the valid range
         token_ids = np.where((item_ids >= 0) & (item_ids < self.vocab_size), item_ids, self.unk_id)
         return token_ids

     def decode(self, token_ids):
-        """
-
-
-
-
-
-
+        """Convert token_ids back to item_ids.
+
+        Parameters
+        ----------
+        token_ids : np.ndarray
+            Token ids.
+
+        Returns
+        -------
+        np.ndarray
+            Item ids.
         """
         # Handle token_ids outside the valid range
         item_ids = np.where((token_ids >= 0) & (token_ids < self.vocab_size), token_ids, self.unk_id)