torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +213 -150
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +107 -0
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +52 -3
- torch_rechub/trainers/match_trainer.py +52 -3
- torch_rechub/trainers/mtl_trainer.py +61 -3
- torch_rechub/trainers/seq_trainer.py +93 -17
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +167 -137
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/METADATA +20 -5
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/RECORD +27 -17
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/utils/data.py
CHANGED
@@ -82,57 +82,67 @@ class DataGenerator(object):


 def get_auto_embedding_dim(num_classes):
-    """
-    …
+    """Calculate embedding dim by category size.
+
+    Uses ``emb_dim = floor(6 * num_classes**0.25)`` from DCN (ADKDD'17).
+
+    Parameters
+    ----------
+    num_classes : int
+        Number of categorical classes.
+
+    Returns
+    -------
+    int
+        Recommended embedding dimension.
     """
     return int(np.floor(6 * np.pow(num_classes, 0.25)))


 def get_loss_func(task_type="classification"):
+    """Return default loss by task type."""
     if task_type == "classification":
         return torch.nn.BCELoss()
-    …
+    if task_type == "regression":
         return torch.nn.MSELoss()
-    …
-    raise ValueError("task_type must be classification or regression")
+    raise ValueError("task_type must be classification or regression")


 def get_metric_func(task_type="classification"):
+    """Return default metric by task type."""
     if task_type == "classification":
         return roc_auc_score
-    …
+    if task_type == "regression":
         return mean_squared_error
-    …
-    raise ValueError("task_type must be classification or regression")
+    raise ValueError("task_type must be classification or regression")


 def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols=[], min_item=0, shuffle=True, max_len=50):
-    """
-    …
+    """Generate sequence features and negatives for ranking.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Raw interaction data.
+    user_col : str
+        User id column name.
+    item_col : str
+        Item id column name.
+    time_col : str
+        Timestamp column name.
+    item_attribute_cols : list[str], optional
+        Additional item attribute columns to include in sequences.
+    min_item : int, default=0
+        Minimum items per user; users below are dropped.
+    shuffle : bool, default=True
+        Shuffle train/val/test.
+    max_len : int, default=50
+        Max history length.
+
+    Returns
+    -------
+    tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
+        Train, validation, and test data with sequence features.
     """
     for feat in data:
         le = LabelEncoder()
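As a quick check of the heuristic the new docstring pins down — a standalone sketch, not the package's code. (Note that `np.pow` in the context line above is the NumPy 2.0 alias of `np.power`, so the released function appears to require NumPy >= 2.0.)

    import numpy as np

    def auto_embedding_dim(num_classes):
        # emb_dim = floor(6 * num_classes ** 0.25), as documented above
        return int(np.floor(6 * num_classes ** 0.25))

    print(auto_embedding_dim(10000))  # 10000 ** 0.25 = 10, so 6 * 10 = 60
    print(auto_embedding_dim(50))     # 50 ** 0.25 ≈ 2.66, so floor(15.95) = 15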
@@ -205,12 +215,17 @@ def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols


 def df_to_dict(data):
-    """
-    …
+    """Convert DataFrame to dict inputs accepted by models.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Input dataframe.
+
+    Returns
+    -------
+    dict
+        Mapping of column name to numpy array.
     """
     data_dict = data.to_dict('list')
     for key in data.keys():
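The two context lines above show the whole mechanism; a minimal usage sketch of the documented behavior:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'user_id': [1, 2], 'item_id': [10, 20]})

    # Mirrors the shown body: to_dict('list'), then cast each column to np.array.
    data_dict = df.to_dict('list')
    for key in df.keys():
        data_dict[key] = np.array(data_dict[key])

    print(data_dict)  # {'user_id': array([1, 2]), 'item_id': array([10, 20])}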
@@ -226,20 +241,28 @@ def neg_sample(click_hist, item_size):


 def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
-    """
-    …
+    """Pad list-of-lists sequences to equal length.
+
+    Equivalent to ``tf.keras.preprocessing.sequence.pad_sequences``.
+
+    Parameters
+    ----------
+    sequences : Sequence[Sequence]
+        Input sequences.
+    maxlen : int, optional
+        Maximum length; computed if None.
+    dtype : str, default='int32'
+    padding : {'pre', 'post'}, default='pre'
+        Padding direction.
+    truncating : {'pre', 'post'}, default='pre'
+        Truncation direction.
+    value : float, default=0.0
+        Padding value.
+
+    Returns
+    -------
+    np.ndarray
+        Padded array of shape (n_samples, maxlen).
     """

     assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
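A minimal sketch of the pre-padding/pre-truncation semantics the new docstring describes (the keras-equivalent behavior, not the package's implementation):

    import numpy as np

    def pad_sequences_sketch(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
        if maxlen is None:
            maxlen = max(len(s) for s in sequences)  # computed if None
        out = np.full((len(sequences), maxlen), value, dtype=dtype)
        for i, seq in enumerate(sequences):
            if not len(seq):
                continue  # leave an all-padding row
            seq = seq[-maxlen:] if truncating == 'pre' else seq[:maxlen]
            if padding == 'pre':
                out[i, -len(seq):] = seq  # right-align, pad on the left
            else:
                out[i, :len(seq)] = seq
        return out

    print(pad_sequences_sketch([[1, 2, 3], [4, 5]], maxlen=4))
    # [[0 1 2 3]
    #  [0 0 4 5]]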
@@ -265,13 +288,19 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncati


 def array_replace_with_dict(array, dic):
-    """Replace values in
-    …
+    """Replace values in numpy array using a mapping dict.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        Input array.
+    dic : dict
+        Mapping from old to new values.
+
+    Returns
+    -------
+    np.ndarray
+        Array with values replaced.
     """
     # Extract out keys and values
     k = np.array(list(dic.keys()))
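The context lines suggest the usual vectorized lookup (extract keys and values, then map); a sketch assuming the standard sort + `searchsorted` trick, which requires every array element to be a key of `dic`:

    import numpy as np

    def array_replace_with_dict_sketch(array, dic):
        k = np.array(list(dic.keys()))
        v = np.array(list(dic.values()))
        sidx = k.argsort()                            # sort keys once
        idx = np.searchsorted(k, array, sorter=sidx)  # locate each element
        return v[sidx[idx]]                           # gather mapped values

    arr = np.array([[1, 2], [3, 1]])
    print(array_replace_with_dict_sketch(arr, {1: 10, 2: 20, 3: 30}))
    # [[10 20]
    #  [30 10]]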
@@ -284,19 +313,25 @@ def array_replace_with_dict(array, dic):

 # Temporarily reserved for testing purposes(1985312383@qq.com)
 def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
-    """Build
-    …
+    """Build user history sequences by time.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Must contain ``user_id, item_id, cate_id, time``.
+    seq_feature_col : list, default ['item_id', 'cate_id']
+        Columns to generate sequence features.
+    max_len : int, default=50
+        Max history length.
+    drop_short : int, default=3
+        Drop users with sequence length < drop_short.
+    shuffle : bool, default=True
+        Shuffle outputs.
+
+    Returns
+    -------
+    tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
+        Train/val/test splits with sequence features.
     """
     for feat in data:
         le = LabelEncoder()
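An illustrative sketch of the pattern the docstring describes — per-user histories ordered by time, each prefix predicting the next item (not the package's exact code):

    import pandas as pd

    df = pd.DataFrame({
        'user_id': [1, 1, 1, 2, 2],
        'item_id': [10, 11, 12, 20, 21],
        'time':    [1, 2, 3, 1, 2],
    })

    df = df.sort_values(['user_id', 'time'])
    for user_id, group in df.groupby('user_id'):
        items = group['item_id'].tolist()
        for i in range(1, len(items)):
            hist, target = items[:i], items[i]        # history -> next item
            print(user_id, hist[-50:], '->', target)  # max_len=50 truncation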
@@ -357,30 +392,32 @@ def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50


 class SeqDataset(Dataset):
-    """Sequence dataset for HSTU-style
-    …
-    Shape
-    …
+    """Sequence dataset for HSTU-style next-item prediction.
+
+    Parameters
+    ----------
+    seq_tokens : np.ndarray
+        Token ids, shape ``(num_samples, seq_len)``.
+    seq_positions : np.ndarray
+        Position indices, shape ``(num_samples, seq_len)``.
+    targets : np.ndarray
+        Target token ids, shape ``(num_samples,)``.
+    seq_time_diffs : np.ndarray
+        Time-difference features, shape ``(num_samples, seq_len)``.
+
+    Shape
+    -----
+    Output tuple: ``(seq_tokens, seq_positions, seq_time_diffs, target)``
+
+    Examples
+    --------
+    >>> seq_tokens = np.random.randint(0, 1000, (100, 256))
+    >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(100, axis=0)
+    >>> seq_time_diffs = np.random.randint(0, 86400, (100, 256))
+    >>> targets = np.random.randint(0, 1000, (100,))
+    >>> dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
+    >>> len(dataset)
+    100
     """

     def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -414,29 +451,25 @@ class SeqDataset(Dataset):


 class SequenceDataGenerator(object):
-    """Sequence data generator
-    …
-    >>> seq_time_diffs = np.random.randint(0, 86400, (1000, 256))
-    >>> targets = np.random.randint(0, 1000, (1000,))
-    >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
-    >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
+    """Sequence data generator for HSTU-style models.
+
+    Wraps :class:`SeqDataset` and builds train/val/test loaders.
+
+    Parameters
+    ----------
+    seq_tokens : np.ndarray
+        Token ids, shape ``(num_samples, seq_len)``.
+    seq_positions : np.ndarray
+        Position indices, shape ``(num_samples, seq_len)``.
+    targets : np.ndarray
+        Target token ids, shape ``(num_samples,)``.
+    seq_time_diffs : np.ndarray
+        Time-difference features, shape ``(num_samples, seq_len)``.
+
+    Examples
+    --------
+    >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
+    >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
     """

     def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
@@ -450,22 +483,19 @@ class SequenceDataGenerator(object):
         self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)

     def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
-        """
-        …
-        ...     num_workers=4,
-        ...     split_ratio=(0.7, 0.1, 0.2)
-        ... )
+        """Generate train/val/test dataloaders.
+
+        Parameters
+        ----------
+        batch_size : int, default=32
+        num_workers : int, default=0
+        split_ratio : tuple, default (0.7, 0.1, 0.2)
+            Train/val/test split.
+
+        Returns
+        -------
+        tuple
+            (train_loader, val_loader, test_loader)
         """
         if split_ratio is None:
             split_ratio = (0.7, 0.1, 0.2)
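A sketch of the documented split behavior, assuming a plain `torch.utils.data.random_split` over the wrapped dataset (the helper name and exact splitting logic are assumptions):

    import torch
    from torch.utils.data import DataLoader, random_split

    def make_loaders(dataset, batch_size=32, num_workers=0, split_ratio=(0.7, 0.1, 0.2)):
        n = len(dataset)
        n_train = int(n * split_ratio[0])
        n_val = int(n * split_ratio[1])
        n_test = n - n_train - n_val  # remainder goes to test
        train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])

        def loader(ds, shuffle):
            return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

        return loader(train_set, True), loader(val_set, False), loader(test_set, False)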
torch_rechub/utils/hstu_utils.py
CHANGED
@@ -6,25 +6,27 @@ import torch.nn as nn


 class RelPosBias(nn.Module):
-    """Relative position bias
-    …
-    Shape
-    …
+    """Relative position bias for attention.
+
+    Parameters
+    ----------
+    n_heads : int
+        Number of attention heads.
+    max_seq_len : int
+        Maximum supported sequence length.
+    num_buckets : int, default=32
+        Number of relative position buckets.
+
+    Shape
+    -----
+    Output: ``(1, n_heads, seq_len, seq_len)``
+
+    Examples
+    --------
+    >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
+    >>> bias = rel_pos_bias(256)
+    >>> bias.shape
+    torch.Size([1, 8, 256, 256])
     """

     def __init__(self, n_heads, max_seq_len, num_buckets=32):
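A minimal sketch of a bucketed relative-position bias with the documented output shape; the clamp-style bucketing below is an assumption (the `num_buckets` parameter suggests T5-like bucketing, but the package's scheme is not shown in this hunk):

    import torch
    import torch.nn as nn

    class RelPosBiasSketch(nn.Module):
        def __init__(self, n_heads, num_buckets=32):
            super().__init__()
            self.num_buckets = num_buckets
            self.bias = nn.Embedding(num_buckets, n_heads)  # one bias per bucket/head

        def forward(self, seq_len):
            pos = torch.arange(seq_len)
            rel = pos[None, :] - pos[:, None]            # (L, L) signed offsets
            half = self.num_buckets // 2
            buckets = rel.clamp(-half, half - 1) + half  # shift into [0, num_buckets)
            out = self.bias(buckets)                     # (L, L, n_heads)
            return out.permute(2, 0, 1).unsqueeze(0)     # (1, n_heads, L, L)

    print(RelPosBiasSketch(n_heads=8)(256).shape)  # torch.Size([1, 8, 256, 256])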
@@ -87,22 +89,20 @@ class RelPosBias(nn.Module):


 class VocabMask(nn.Module):
-    """Vocabulary mask
-    …
-    >>> logits = torch.randn(32, 1000)
-    >>> masked_logits = mask.apply_mask(logits)
+    """Vocabulary mask to block invalid items at inference.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size.
+    invalid_items : list, optional
+        IDs to mask out.
+
+    Examples
+    --------
+    >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
+    >>> logits = torch.randn(32, 1000)
+    >>> masked_logits = mask.apply_mask(logits)
     """

     def __init__(self, vocab_size, invalid_items=None):
@@ -123,13 +123,17 @@ class VocabMask(nn.Module):
             self.mask[item_id] = False

     def apply_mask(self, logits):
-        """
-        …
+        """Apply mask to logits.
+
+        Parameters
+        ----------
+        logits : Tensor
+            Model logits, shape ``(..., vocab_size)``.
+
+        Returns
+        -------
+        Tensor
+            Masked logits.
         """
         # Set the logits of invalid items to a very small value
         masked_logits = logits.clone()
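The context lines show the shape of the masking step; a sketch, assuming `torch.finfo(...).min` as the "very small value" (the exact fill constant is not shown in this hunk):

    import torch

    vocab_size = 1000
    valid = torch.ones(vocab_size, dtype=torch.bool)
    valid[[0, 1, 2]] = False                    # invalid items, as in the example

    logits = torch.randn(32, vocab_size)
    masked_logits = logits.clone()
    masked_logits[:, ~valid] = torch.finfo(logits.dtype).min  # never ranked first
    print(masked_logits[:, :3].max())           # every masked entry is tiny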
@@ -139,26 +143,25 @@ class VocabMask(nn.Module):


 class VocabMapper(object):
-    """
-    …
-    >>> decoded_ids = mapper.decode(token_ids)
+    """Identity mapper between ``item_id`` and ``token_id``.
+
+    Useful for sequence generation where items are treated as tokens.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size.
+    pad_id : int, default=0
+        PAD token id.
+    unk_id : int, default=1
+        Unknown token id.
+
+    Examples
+    --------
+    >>> mapper = VocabMapper(vocab_size=1000)
+    >>> item_ids = np.array([10, 20, 30])
+    >>> token_ids = mapper.encode(item_ids)
+    >>> decoded_ids = mapper.decode(token_ids)
     """

     def __init__(self, vocab_size, pad_id=0, unk_id=1):
@@ -172,26 +175,34 @@ class VocabMapper(object):
         self.token2item = np.arange(vocab_size)

     def encode(self, item_ids):
-        """
-        …
+        """Convert item_ids to token_ids.
+
+        Parameters
+        ----------
+        item_ids : np.ndarray
+            Item ids.
+
+        Returns
+        -------
+        np.ndarray
+            Token ids.
         """
         # Handle item_ids that fall outside the vocabulary range
         token_ids = np.where((item_ids >= 0) & (item_ids < self.vocab_size), item_ids, self.unk_id)
         return token_ids

     def decode(self, token_ids):
-        """
-        …
+        """Convert token_ids back to item_ids.
+
+        Parameters
+        ----------
+        token_ids : np.ndarray
+            Token ids.
+
+        Returns
+        -------
+        np.ndarray
+            Item ids.
         """
         # Handle token_ids that fall outside the vocabulary range
         item_ids = np.where((token_ids >= 0) & (token_ids < self.vocab_size), token_ids, self.unk_id)
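The `np.where` bodies are shown verbatim above, so the clipping behavior can be checked directly:

    import numpy as np

    vocab_size, unk_id = 1000, 1
    item_ids = np.array([10, 20, 1500, -3])  # two ids fall outside the vocab

    token_ids = np.where((item_ids >= 0) & (item_ids < vocab_size), item_ids, unk_id)
    print(token_ids)  # [10 20  1  1] -- out-of-range ids collapse to unk_id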
torch_rechub/utils/onnx_export.py
CHANGED

@@ -26,32 +26,30 @@ except ImportError:


 def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
-    """Extract feature information from a torch-rechub model
-
-    This function inspects model attributes to find feature lists without
-    modifying the model code. Supports various model architectures.
+    """Extract feature information from a torch-rechub model via reflection.

     Parameters
     ----------
     model : nn.Module
-        …
+        Model to inspect.

     Returns
     -------
     dict
-        …
+        {
+            'features': list of unique Feature objects,
+            'input_names': ordered feature names,
+            'input_types': map name -> feature type,
+            'user_features': user-side features (dual-tower),
+            'item_features': item-side features (dual-tower),
+        }

     Examples
     --------
     >>> from torch_rechub.models.ranking import DeepFM
     >>> model = DeepFM(deep_features, fm_features, mlp_params)
     >>> info = extract_feature_info(model)
-    >>>
+    >>> info['input_names']  # ['user_id', 'item_id', ...]
     """
     # Common feature attribute names across different model types
     feature_attrs = [