torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
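Several of these files are new in 0.0.5, most notably the generative models (hstu.py, hllm.py), the sequence trainer, and the ONNX-export and visualization utilities. As a quick, hedged sanity check, the sketch below probes an installed copy of the package for those module paths; the paths come from the list above, everything else in the snippet is purely illustrative.

# Sketch: check which of the newly listed modules exist in the installed torch-rechub.
import importlib.util

new_modules = [
    "torch_rechub.models.generative.hstu",
    "torch_rechub.models.generative.hllm",
    "torch_rechub.trainers.seq_trainer",
    "torch_rechub.utils.onnx_export",
    "torch_rechub.utils.visualization",
]

for mod in new_modules:
    status = "present" if importlib.util.find_spec(mod) else "missing"
    print(f"{mod}: {status}")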
torch_rechub/__init__.py CHANGED
@@ -0,0 +1,14 @@
+"""Torch-RecHub: A PyTorch Toolbox for Recommendation Models."""
+
+__version__ = "0.1.0"
+
+# Import main modules
+from . import basic, models, trainers, utils
+
+__all__ = [
+    "__version__",
+    "basic",
+    "models",
+    "trainers",
+    "utils",
+]
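The new package-level __init__.py declares __version__ = "0.1.0" even though this wheel is published as 0.0.5, so the module attribute and the installed distribution metadata can disagree. A small illustrative check (assumes the 0.0.5 wheel is installed):

# Compare the module attribute with the installed wheel's metadata.
from importlib.metadata import version

import torch_rechub

print(torch_rechub.__version__)       # "0.1.0", per the new __init__.py above
print(version("torch-rechub"))        # "0.0.5", from the installed wheel metadata
print(torch_rechub.__all__)           # ['__version__', 'basic', 'models', 'trainers', 'utils']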
torch_rechub/basic/activation.py CHANGED
@@ -1,54 +1,54 @@
-import torch
-import torch.nn as nn
-
-
-class Dice(nn.Module):
-    """The Dice activation function mentioned in the `DIN paper
-    https://arxiv.org/abs/1706.06978`
-    """
-
-    def __init__(self, epsilon=1e-3):
-        super(Dice, self).__init__()
-        self.epsilon = epsilon
-        self.alpha = nn.Parameter(torch.randn(1))
-
-    def forward(self, x: torch.Tensor):
-        # x: N * num_neurons
-        avg = x.mean(dim=1)  # N
-        avg = avg.unsqueeze(dim=1)  # N * 1
-        var = torch.pow(x - avg, 2) + self.epsilon  # N * num_neurons
-        var = var.sum(dim=1).unsqueeze(dim=1)  # N * 1
-
-        ps = (x - avg) / torch.sqrt(var)  # N * 1
-
-        ps = nn.Sigmoid()(ps)  # N * 1
-        return ps * x + (1 - ps) * self.alpha * x
-
-
-def activation_layer(act_name):
-    """Construct activation layers
-
-    Args:
-        act_name: str or nn.Module, name of activation function
-
-    Returns:
-        act_layer: activation layer
-    """
-    if isinstance(act_name, str):
-        if act_name.lower() == 'sigmoid':
-            act_layer = nn.Sigmoid()
-        elif act_name.lower() == 'relu':
-            act_layer = nn.ReLU(inplace=True)
-        elif act_name.lower() == 'dice':
-            act_layer = Dice()
-        elif act_name.lower() == 'prelu':
-            act_layer = nn.PReLU()
-        elif act_name.lower() == "softmax":
-            act_layer = nn.Softmax(dim=1)
-        elif act_name.lower() == 'leakyrelu':
-            act_layer = nn.LeakyReLU()
-    elif issubclass(act_name, nn.Module):
-        act_layer = act_name()
-    else:
-        raise NotImplementedError
-    return act_layer
+import torch
+import torch.nn as nn
+
+
+class Dice(nn.Module):
+    """The Dice activation function mentioned in the `DIN paper
+    https://arxiv.org/abs/1706.06978`
+    """
+
+    def __init__(self, epsilon=1e-3):
+        super(Dice, self).__init__()
+        self.epsilon = epsilon
+        self.alpha = nn.Parameter(torch.randn(1))
+
+    def forward(self, x: torch.Tensor):
+        # x: N * num_neurons
+        avg = x.mean(dim=1)  # N
+        avg = avg.unsqueeze(dim=1)  # N * 1
+        var = torch.pow(x - avg, 2) + self.epsilon  # N * num_neurons
+        var = var.sum(dim=1).unsqueeze(dim=1)  # N * 1
+
+        ps = (x - avg) / torch.sqrt(var)  # N * 1
+
+        ps = nn.Sigmoid()(ps)  # N * 1
+        return ps * x + (1 - ps) * self.alpha * x
+
+
+def activation_layer(act_name):
+    """Construct activation layers
+
+    Args:
+        act_name: str or nn.Module, name of activation function
+
+    Returns:
+        act_layer: activation layer
+    """
+    if isinstance(act_name, str):
+        if act_name.lower() == 'sigmoid':
+            act_layer = nn.Sigmoid()
+        elif act_name.lower() == 'relu':
+            act_layer = nn.ReLU(inplace=True)
+        elif act_name.lower() == 'dice':
+            act_layer = Dice()
+        elif act_name.lower() == 'prelu':
+            act_layer = nn.PReLU()
+        elif act_name.lower() == "softmax":
+            act_layer = nn.Softmax(dim=1)
+        elif act_name.lower() == 'leakyrelu':
+            act_layer = nn.LeakyReLU()
+    elif issubclass(act_name, nn.Module):
+        act_layer = act_name()
+    else:
+        raise NotImplementedError
+    return act_layer
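The old and new activation.py contents shown above are identical line for line, so the rewrite appears to be purely a formatting/line-ending change. For orientation, here is a minimal usage sketch of the two names the file exports (the tensor shape is an arbitrary example, not anything the package prescribes):

# Minimal sketch: build activations via activation_layer and apply Dice.
import torch
from torch_rechub.basic.activation import Dice, activation_layer

x = torch.randn(8, 16)              # batch of 8 samples, 16 neurons

dice = activation_layer("dice")     # returns a Dice() instance
relu = activation_layer("relu")     # returns nn.ReLU(inplace=True)

print(dice(x).shape)                # torch.Size([8, 16])
print(relu(x.clone()).shape)        # torch.Size([8, 16]); clone() keeps the in-place ReLU off x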
torch_rechub/basic/callback.py CHANGED
@@ -1,33 +1,33 @@
-import copy
-
-
-class EarlyStopper(object):
-    """Early stops the training if validation loss doesn't improve after a given patience.
-
-    Args:
-        patience (int): How long to wait after last time validation auc improved.
-    """
-
-    def __init__(self, patience):
-        self.patience = patience
-        self.trial_counter = 0
-        self.best_auc = 0
-        self.best_weights = None
-
-    def stop_training(self, val_auc, weights):
-        """whether to stop training.
-
-        Args:
-            val_auc (float): auc score in val data.
-            weights (tensor): the weights of model
-        """
-        if val_auc > self.best_auc:
-            self.best_auc = val_auc
-            self.trial_counter = 0
-            self.best_weights = copy.deepcopy(weights)
-            return False
-        elif self.trial_counter + 1 < self.patience:
-            self.trial_counter += 1
-            return False
-        else:
-            return True
+import copy
+
+
+class EarlyStopper(object):
+    """Early stops the training if validation loss doesn't improve after a given patience.
+
+    Args:
+        patience (int): How long to wait after last time validation auc improved.
+    """
+
+    def __init__(self, patience):
+        self.patience = patience
+        self.trial_counter = 0
+        self.best_auc = 0
+        self.best_weights = None
+
+    def stop_training(self, val_auc, weights):
+        """whether to stop training.
+
+        Args:
+            val_auc (float): auc score in val data.
+            weights (tensor): the weights of model
+        """
+        if val_auc > self.best_auc:
+            self.best_auc = val_auc
+            self.trial_counter = 0
+            self.best_weights = copy.deepcopy(weights)
+            return False
+        elif self.trial_counter + 1 < self.patience:
+            self.trial_counter += 1
+            return False
+        else:
+            return True
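callback.py is likewise carried over unchanged in substance. A self-contained sketch of how EarlyStopper is typically driven (the model and the AUC values are synthetic stand-ins, not part of the package):

# Sketch: EarlyStopper with patience=2 and synthetic validation AUCs.
import torch
from torch_rechub.basic.callback import EarlyStopper

model = torch.nn.Linear(4, 1)                      # stand-in for a real recommender
fake_val_aucs = [0.70, 0.72, 0.71, 0.71, 0.71]     # synthetic per-epoch scores

early_stopper = EarlyStopper(patience=2)
for epoch, val_auc in enumerate(fake_val_aucs):
    if early_stopper.stop_training(val_auc, model.state_dict()):
        print(f"stopping at epoch {epoch}, best AUC {early_stopper.best_auc}")
        model.load_state_dict(early_stopper.best_weights)   # restore the best snapshot
        break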
torch_rechub/basic/features.py CHANGED
@@ -1,94 +1,87 @@
-from ..utils.data import get_auto_embedding_dim
-from .initializers import RandomNormal
-
-
-class SequenceFeature(object):
-    """The Feature Class for Sequence feature or multi-hot feature.
-    In recommendation, there are many user behaviour features which we want to take the sequence model
-    and tag featurs (multi hot) which we want to pooling. Note that if you use this feature, you must padding
-    the feature value before training.
-
-    Args:
-        name (str): feature's name.
-        vocab_size (int): vocabulary size of embedding table.
-        embed_dim (int): embedding vector's length
-        pooling (str): pooling method, support `["mean", "sum", "concat"]` (default=`"mean"`)
-        shared_with (str): the another feature name which this feature will shared with embedding.
-        padding_idx (int, optional): If specified, the entries at padding_idx will be masked 0 in InputMask Layer.
-        initializer(Initializer): Initializer the embedding layer weight.
-    """
-
-    def __init__(self,
-                 name,
-                 vocab_size,
-                 embed_dim=None,
-                 pooling="mean",
-                 shared_with=None,
-                 padding_idx=None,
-                 initializer=RandomNormal(0, 0.0001)):
-        self.name = name
-        self.vocab_size = vocab_size
-        if embed_dim is None:
-            self.embed_dim = get_auto_embedding_dim(vocab_size)
-        else:
-            self.embed_dim = embed_dim
-        self.pooling = pooling
-        self.shared_with = shared_with
-        self.padding_idx = padding_idx
-        self.initializer = initializer
-
-    def __repr__(self):
-        return f'<SequenceFeature {self.name} with Embedding shape ({self.vocab_size}, {self.embed_dim})>'
-
-    def get_embedding_layer(self):
-        if not hasattr(self, 'embed'):
-            self.embed = self.initializer(self.vocab_size, self.embed_dim)
-        return self.embed
-
-
-class SparseFeature(object):
-    """The Feature Class for Sparse feature.
-
-    Args:
-        name (str): feature's name.
-        vocab_size (int): vocabulary size of embedding table.
-        embed_dim (int): embedding vector's length
-        shared_with (str): the another feature name which this feature will shared with embedding.
-        padding_idx (int, optional): If specified, the entries at padding_idx will be masked 0 in InputMask Layer.
-        initializer(Initializer): Initializer the embedding layer weight.
-    """
-
-    def __init__(self, name, vocab_size, embed_dim=None, shared_with=None, padding_idx=None, initializer=RandomNormal(0, 0.0001)):
-        self.name = name
-        self.vocab_size = vocab_size
-        if embed_dim is None:
-            self.embed_dim = get_auto_embedding_dim(vocab_size)
-        else:
-            self.embed_dim = embed_dim
-        self.shared_with = shared_with
-        self.padding_idx = padding_idx
-        self.initializer = initializer
-
-    def __repr__(self):
-        return f'<SparseFeature {self.name} with Embedding shape ({self.vocab_size}, {self.embed_dim})>'
-
-    def get_embedding_layer(self):
-        if not hasattr(self, 'embed'):
-            self.embed = self.initializer(self.vocab_size, self.embed_dim)
-        return self.embed
-
-
-class DenseFeature(object):
-    """The Feature Class for Dense feature.
-
-    Args:
-        name (str): feature's name.
-        embed_dim (int): embedding vector's length, the value fixed `1`. If you put a vector (torch.tensor) , replace the embed_dim with your vector dimension.
-    """
-
-    def __init__(self, name, embed_dim = 1):
-        self.name = name
-        self.embed_dim = embed_dim
-
-    def __repr__(self):
-        return f'<DenseFeature {self.name}>'
+from ..utils.data import get_auto_embedding_dim
+from .initializers import RandomNormal
+
+
+class SequenceFeature(object):
+    """The Feature Class for Sequence feature or multi-hot feature.
+    In recommendation, there are many user behaviour features which we want to take the sequence model
+    and tag featurs (multi hot) which we want to pooling. Note that if you use this feature, you must padding
+    the feature value before training.
+
+    Args:
+        name (str): feature's name.
+        vocab_size (int): vocabulary size of embedding table.
+        embed_dim (int): embedding vector's length
+        pooling (str): pooling method, support `["mean", "sum", "concat"]` (default=`"mean"`)
+        shared_with (str): the another feature name which this feature will shared with embedding.
+        padding_idx (int, optional): If specified, the entries at padding_idx will be masked 0 in InputMask Layer.
+        initializer(Initializer): Initializer the embedding layer weight.
+    """
+
+    def __init__(self, name, vocab_size, embed_dim=None, pooling="mean", shared_with=None, padding_idx=None, initializer=RandomNormal(0, 0.0001)):
+        self.name = name
+        self.vocab_size = vocab_size
+        if embed_dim is None:
+            self.embed_dim = get_auto_embedding_dim(vocab_size)
+        else:
+            self.embed_dim = embed_dim
+        self.pooling = pooling
+        self.shared_with = shared_with
+        self.padding_idx = padding_idx
+        self.initializer = initializer
+
+    def __repr__(self):
+        return f'<SequenceFeature {self.name} with Embedding shape ({self.vocab_size}, {self.embed_dim})>'
+
+    def get_embedding_layer(self):
+        if not hasattr(self, 'embed'):
+            self.embed = self.initializer(self.vocab_size, self.embed_dim)
+        return self.embed
+
+
+class SparseFeature(object):
+    """The Feature Class for Sparse feature.
+
+    Args:
+        name (str): feature's name.
+        vocab_size (int): vocabulary size of embedding table.
+        embed_dim (int): embedding vector's length
+        shared_with (str): the another feature name which this feature will shared with embedding.
+        padding_idx (int, optional): If specified, the entries at padding_idx will be masked 0 in InputMask Layer.
+        initializer(Initializer): Initializer the embedding layer weight.
+    """
+
+    def __init__(self, name, vocab_size, embed_dim=None, shared_with=None, padding_idx=None, initializer=RandomNormal(0, 0.0001)):
+        self.name = name
+        self.vocab_size = vocab_size
+        if embed_dim is None:
+            self.embed_dim = get_auto_embedding_dim(vocab_size)
+        else:
+            self.embed_dim = embed_dim
+        self.shared_with = shared_with
+        self.padding_idx = padding_idx
+        self.initializer = initializer
+
+    def __repr__(self):
+        return f'<SparseFeature {self.name} with Embedding shape ({self.vocab_size}, {self.embed_dim})>'
+
+    def get_embedding_layer(self):
+        if not hasattr(self, 'embed'):
+            self.embed = self.initializer(self.vocab_size, self.embed_dim)
+        return self.embed
+
+
+class DenseFeature(object):
+    """The Feature Class for Dense feature.
+
+    Args:
+        name (str): feature's name.
+        embed_dim (int): embedding vector's length, the value fixed `1`. If you put a vector (torch.tensor) , replace the embed_dim with your vector dimension.
+    """
+
+    def __init__(self, name, embed_dim=1):
+        self.name = name
+        self.embed_dim = embed_dim
+
+    def __repr__(self):
+        return f'<DenseFeature {self.name}>'
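Within features.py the edits are cosmetic: the SequenceFeature constructor signature is collapsed onto a single line and the DenseFeature default loses a stray space (embed_dim = 1 becomes embed_dim=1); behaviour is unchanged. A brief illustrative construction of the three feature types (names and vocabulary sizes are made up):

# Illustrative feature definitions; every name and size here is invented.
from torch_rechub.basic.features import DenseFeature, SequenceFeature, SparseFeature

user_id = SparseFeature("user_id", vocab_size=10000, embed_dim=16)
age = DenseFeature("age")                # embed_dim defaults to 1
hist_items = SequenceFeature("hist_item_id",
                             vocab_size=5000,
                             embed_dim=16,
                             pooling="mean",
                             shared_with="item_id",   # reuse the item_id embedding table
                             padding_idx=0)

print(user_id)                           # <SparseFeature user_id with Embedding shape (10000, 16)>
print(hist_items.get_embedding_layer())  # nn.Embedding(5000, 16), initialized by RandomNormal(0, 0.0001)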
torch_rechub/basic/initializers.py CHANGED
@@ -1,92 +1,92 @@
-import torch
-
-
-class RandomNormal(object):
-    """Returns an embedding initialized with a normal distribution.
-
-    Args:
-        mean (float): the mean of the normal distribution
-        std (float): the standard deviation of the normal distribution
-    """
-
-    def __init__(self, mean=0.0, std=1.0):
-        self.mean = mean
-        self.std = std
-
-    def __call__(self, vocab_size, embed_dim):
-        embed = torch.nn.Embedding(vocab_size, embed_dim)
-        torch.nn.init.normal_(embed.weight, self.mean, self.std)
-        return embed
-
-
-class RandomUniform(object):
-    """Returns an embedding initialized with a uniform distribution.
-
-    Args:
-        minval (float): Lower bound of the range of random values of the uniform distribution.
-        maxval (float): Upper bound of the range of random values of the uniform distribution.
-    """
-
-    def __init__(self, minval=0.0, maxval=1.0):
-        self.minval = minval
-        self.maxval = maxval
-
-    def __call__(self, vocab_size, embed_dim):
-        embed = torch.nn.Embedding(vocab_size, embed_dim)
-        torch.nn.init.uniform_(embed.weight, self.minval, self.maxval)
-        return embed
-
-
-class XavierNormal(object):
-    """Returns an embedding initialized with the method described in
-    `Understanding the difficulty of training deep feedforward neural networks`
-    - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.
-
-    Args:
-        gain (float): stddev = gain*sqrt(2 / (fan_in + fan_out))
-    """
-
-    def __init__(self, gain=1.0):
-        self.gain = gain
-
-    def __call__(self, vocab_size, embed_dim):
-        embed = torch.nn.Embedding(vocab_size, embed_dim)
-        torch.nn.init.xavier_normal_(embed.weight, self.gain)
-        return embed
-
-
-class XavierUniform(object):
-    """Returns an embedding initialized with the method described in
-    `Understanding the difficulty of training deep feedforward neural networks`
-    - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.
-
-    Args:
-        gain (float): stddev = gain*sqrt(6 / (fan_in + fan_out))
-    """
-
-    def __init__(self, gain=1.0):
-        self.gain = gain
-
-    def __call__(self, vocab_size, embed_dim):
-        embed = torch.nn.Embedding(vocab_size, embed_dim)
-        torch.nn.init.xavier_uniform_(embed.weight, self.gain)
-        return embed
-
-
-class Pretrained(object):
-    """Creates Embedding instance from given 2-dimensional FloatTensor.
-
-    Args:
-        embedding_weight(Tensor or ndarray or List[List[int]]): FloatTensor containing weights for the Embedding.
-            First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
-        freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
-    """
-
-    def __init__(self, embedding_weight, freeze=True):
-        self.embedding_weight = torch.FloatTensor(embedding_weight)
-        self.freeze = freeze
-
-    def __call__(self, vocab_size, embed_dim):
-        assert vocab_size == self.embedding_weight.shape[0] and embed_dim == self.embedding_weight.shape[1]
-        embed = torch.nn.Embedding.from_pretrained(self.embedding_weight, freeze=self.freeze)
-        return embed
+import torch
+
+
+class RandomNormal(object):
+    """Returns an embedding initialized with a normal distribution.
+
+    Args:
+        mean (float): the mean of the normal distribution
+        std (float): the standard deviation of the normal distribution
+    """
+
+    def __init__(self, mean=0.0, std=1.0):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, vocab_size, embed_dim):
+        embed = torch.nn.Embedding(vocab_size, embed_dim)
+        torch.nn.init.normal_(embed.weight, self.mean, self.std)
+        return embed
+
+
+class RandomUniform(object):
+    """Returns an embedding initialized with a uniform distribution.
+
+    Args:
+        minval (float): Lower bound of the range of random values of the uniform distribution.
+        maxval (float): Upper bound of the range of random values of the uniform distribution.
+    """
+
+    def __init__(self, minval=0.0, maxval=1.0):
+        self.minval = minval
+        self.maxval = maxval
+
+    def __call__(self, vocab_size, embed_dim):
+        embed = torch.nn.Embedding(vocab_size, embed_dim)
+        torch.nn.init.uniform_(embed.weight, self.minval, self.maxval)
+        return embed
+
+
+class XavierNormal(object):
+    """Returns an embedding initialized with the method described in
+    `Understanding the difficulty of training deep feedforward neural networks`
+    - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.
+
+    Args:
+        gain (float): stddev = gain*sqrt(2 / (fan_in + fan_out))
+    """
+
+    def __init__(self, gain=1.0):
+        self.gain = gain
+
+    def __call__(self, vocab_size, embed_dim):
+        embed = torch.nn.Embedding(vocab_size, embed_dim)
+        torch.nn.init.xavier_normal_(embed.weight, self.gain)
+        return embed
+
+
+class XavierUniform(object):
+    """Returns an embedding initialized with the method described in
+    `Understanding the difficulty of training deep feedforward neural networks`
+    - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.
+
+    Args:
+        gain (float): stddev = gain*sqrt(6 / (fan_in + fan_out))
+    """
+
+    def __init__(self, gain=1.0):
+        self.gain = gain
+
+    def __call__(self, vocab_size, embed_dim):
+        embed = torch.nn.Embedding(vocab_size, embed_dim)
+        torch.nn.init.xavier_uniform_(embed.weight, self.gain)
+        return embed
+
+
+class Pretrained(object):
+    """Creates Embedding instance from given 2-dimensional FloatTensor.
+
+    Args:
+        embedding_weight(Tensor or ndarray or List[List[int]]): FloatTensor containing weights for the Embedding.
+            First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
+        freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+    """
+
+    def __init__(self, embedding_weight, freeze=True):
+        self.embedding_weight = torch.FloatTensor(embedding_weight)
+        self.freeze = freeze
+
+    def __call__(self, vocab_size, embed_dim):
+        assert vocab_size == self.embedding_weight.shape[0] and embed_dim == self.embedding_weight.shape[1]
+        embed = torch.nn.Embedding.from_pretrained(self.embedding_weight, freeze=self.freeze)
+        return embed
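initializers.py is also reproduced unchanged. Each class is a small factory that builds and initializes a torch.nn.Embedding when called with (vocab_size, embed_dim); a short sketch with arbitrary sizes and weights:

# Sketch of the initializer factories above; sizes and the pretrained matrix are arbitrary.
import torch
from torch_rechub.basic.initializers import Pretrained, RandomUniform, XavierNormal

embed_a = RandomUniform(minval=-0.01, maxval=0.01)(vocab_size=100, embed_dim=8)
embed_b = XavierNormal(gain=1.0)(100, 8)

weights = torch.randn(100, 8)                        # e.g. vectors exported from another model
embed_c = Pretrained(weights, freeze=True)(100, 8)   # asserts the shape matches, then freezes

print(embed_a.weight.shape)            # torch.Size([100, 8])
print(embed_c.weight.requires_grad)    # False, because freeze=True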