SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
- SinaTools-1.0.1.dist-info/RECORD +73 -0
- sinatools/VERSION +1 -1
- sinatools/ner/__init__.py +5 -7
- sinatools/ner/trainers/BertNestedTrainer.py +203 -203
- sinatools/ner/trainers/BertTrainer.py +163 -163
- sinatools/ner/trainers/__init__.py +2 -2
- SinaTools-0.1.40.dist-info/RECORD +0 -123
- sinatools/arabert/arabert/__init__.py +0 -14
- sinatools/arabert/arabert/create_classification_data.py +0 -260
- sinatools/arabert/arabert/create_pretraining_data.py +0 -534
- sinatools/arabert/arabert/extract_features.py +0 -444
- sinatools/arabert/arabert/lamb_optimizer.py +0 -158
- sinatools/arabert/arabert/modeling.py +0 -1027
- sinatools/arabert/arabert/optimization.py +0 -202
- sinatools/arabert/arabert/run_classifier.py +0 -1078
- sinatools/arabert/arabert/run_pretraining.py +0 -593
- sinatools/arabert/arabert/run_squad.py +0 -1440
- sinatools/arabert/arabert/tokenization.py +0 -414
- sinatools/arabert/araelectra/__init__.py +0 -1
- sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
- sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
- sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
- sinatools/arabert/araelectra/configure_finetuning.py +0 -172
- sinatools/arabert/araelectra/configure_pretraining.py +0 -143
- sinatools/arabert/araelectra/finetune/__init__.py +0 -14
- sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
- sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
- sinatools/arabert/araelectra/finetune/scorer.py +0 -54
- sinatools/arabert/araelectra/finetune/task.py +0 -74
- sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
- sinatools/arabert/araelectra/flops_computation.py +0 -215
- sinatools/arabert/araelectra/model/__init__.py +0 -14
- sinatools/arabert/araelectra/model/modeling.py +0 -1029
- sinatools/arabert/araelectra/model/optimization.py +0 -193
- sinatools/arabert/araelectra/model/tokenization.py +0 -355
- sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
- sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
- sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
- sinatools/arabert/araelectra/run_finetuning.py +0 -323
- sinatools/arabert/araelectra/run_pretraining.py +0 -469
- sinatools/arabert/araelectra/util/__init__.py +0 -14
- sinatools/arabert/araelectra/util/training_utils.py +0 -112
- sinatools/arabert/araelectra/util/utils.py +0 -109
- sinatools/arabert/aragpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
- sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
- sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
- sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
- sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
- sinatools/arabert/aragpt2/grover/__init__.py +0 -0
- sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
- sinatools/arabert/aragpt2/grover/modeling.py +0 -803
- sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
- sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
- sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
- sinatools/arabert/aragpt2/grover/utils.py +0 -234
- sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
- {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
- {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
@@ -1,160 +0,0 @@
|
|
1
|
-
# coding=utf-8
|
2
|
-
# Copyright 2020 The Google Research Authors.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
"""Helpers for preparing pre-training data and supplying them to the model."""
|
17
|
-
|
18
|
-
from __future__ import absolute_import
|
19
|
-
from __future__ import division
|
20
|
-
from __future__ import print_function
|
21
|
-
|
22
|
-
import collections
|
23
|
-
|
24
|
-
import numpy as np
|
25
|
-
import tensorflow as tf
|
26
|
-
|
27
|
-
import configure_pretraining
|
28
|
-
from model import tokenization
|
29
|
-
from util import utils
|
30
|
-
|
31
|
-
|
32
|
-
def get_input_fn(config: configure_pretraining.PretrainingConfig, is_training,
                 num_cpu_threads=4):
    """Build the `input_fn` closure that is handed to TPUEstimator.

    Args:
      config: pretraining configuration; `pretrain_tfrecords` (comma-separated
        glob patterns) and `max_seq_length` are read from it.
      is_training: only controls whether files are interleaved in `sloppy`
        (inexact order) mode.
      num_cpu_threads: upper bound on the number of files read in parallel
        and the number of batches produced in parallel.

    Returns:
      A function that takes `params` (must contain "batch_size") and returns
      the batched dataset of decoded examples.
    """
    # Expand every comma-separated glob pattern once, outside the closure.
    input_files = [
        path
        for pattern in config.pretrain_tfrecords.split(",")
        for path in tf.io.gfile.glob(pattern)
    ]

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # All three features are fixed-length int64 vectors of max_seq_length.
        feature_spec = {
            key: tf.io.FixedLenFeature([config.max_seq_length], tf.int64)
            for key in ("input_ids", "input_mask", "segment_ids")
        }

        file_count = len(input_files)
        dataset = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
        dataset = dataset.repeat().shuffle(buffer_size=file_count)

        # `cycle_length` is the number of files read in parallel. `sloppy`
        # mode makes the interleaving inexact, which adds extra randomness.
        interleave = tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset,
            sloppy=is_training,
            cycle_length=min(num_cpu_threads, file_count))
        dataset = dataset.apply(interleave).shuffle(buffer_size=100)

        # The remainder is always dropped here (for eval as well as training)
        # so every batch has the fixed size that the TPU requires.
        batcher = tf.data.experimental.map_and_batch(
            lambda record: _decode_record(record, feature_spec),
            batch_size=batch_size,
            num_parallel_batches=num_cpu_threads,
            drop_remainder=True)
        return dataset.apply(batcher)

    return input_fn
|
79
|
-
|
80
|
-
|
81
|
-
def _decode_record(record, name_to_features):
    """Parse one serialized record and downcast its int64 features to int32.

    Args:
      record: serialized example passed to `tf.io.parse_single_example`.
      name_to_features: feature spec mapping each feature name to its parser
        configuration.

    Returns:
      The parsed feature dict with every tf.int64 tensor cast to tf.int32.
    """
    parsed = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so every int64 feature is cast down.
    for key, tensor in list(parsed.items()):
        if tensor.dtype == tf.int64:
            parsed[key] = tf.cast(tensor, tf.int32)

    return parsed
|
94
|
-
|
95
|
-
|
96
|
-
# Model inputs: a namedtuple is a bit nicer to pass around than keeping the
# features in a plain dict.
Inputs = collections.namedtuple(
    "Inputs",
    [
        "input_ids",
        "input_mask",
        "segment_ids",
        "masked_lm_positions",
        "masked_lm_ids",
        "masked_lm_weights",
    ],
)
|
101
|
-
|
102
|
-
|
103
|
-
def features_to_inputs(features):
    """Convert a feature mapping into an `Inputs` tuple.

    `input_ids`, `input_mask` and `segment_ids` must be present in
    `features`; the three `masked_lm_*` entries are optional and become None
    when absent. Any other keys in `features` are not read.
    """
    return Inputs(
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        segment_ids=features["segment_ids"],
        masked_lm_positions=features.get("masked_lm_positions"),
        masked_lm_ids=features.get("masked_lm_ids"),
        masked_lm_weights=features.get("masked_lm_weights"),
    )
|
115
|
-
|
116
|
-
|
117
|
-
def get_updated_inputs(inputs, **kwargs):
    """Return a new `Inputs` built from `inputs` with the given fields replaced.

    Args:
      inputs: the namedtuple to start from (its `_asdict()` is used).
      **kwargs: field values that override those of `inputs`.
    """
    merged = inputs._asdict()
    merged.update(kwargs)
    return features_to_inputs(merged)
|
122
|
-
|
123
|
-
|
124
|
-
# Terminal escape sequences used to colorize pretty-printed tokens; ENDC
# is appended after a colored span.
ENDC = "\033[0m"
COLORS = ["\033[{}m".format(code) for code in (91, 92, 93, 94, 95, 96, 90)]
RED, GREEN, BLUE, CYAN = COLORS[0], COLORS[1], COLORS[3], COLORS[5]
|
130
|
-
|
131
|
-
|
132
|
-
def print_tokens(inputs: Inputs, inv_vocab, updates_mask=None):
    """Pretty-print model inputs.

    Logs the tokens of the first example (index 0) of `inputs`, stopping at
    the first "[PAD]" token. A position that appears in the masked-LM fields
    with a non-zero weight is printed in red, followed in parentheses by the
    token for the corresponding `masked_lm_ids` entry.

    Args:
      inputs: Inputs whose fields can be indexed with [0] and iterated.
      inv_vocab: lookup from token id to token string.
      updates_mask: optional mask, indexed like `input_ids`. When provided it
        is asserted to be 1 at every masked position and 0 everywhere else.
    """
    # Map each masked position to its masked-LM id, skipping zero-weight slots.
    masked_tokid_at = {
        pos: tokid
        for tokid, pos, weight in zip(inputs.masked_lm_ids[0],
                                      inputs.masked_lm_positions[0],
                                      inputs.masked_lm_weights[0])
        if weight != 0
    }

    check_mask = updates_mask is not None
    if not check_mask:
        # Placeholder so the zip below has something to iterate over.
        updates_mask = np.zeros_like(inputs.input_ids)

    pieces = []
    for pos, (tokid, um) in enumerate(
            zip(inputs.input_ids[0], updates_mask[0])):
        token = inv_vocab[tokid]
        if token == "[PAD]":
            break
        is_masked = pos in masked_tokid_at
        if is_masked:
            token = (RED + token + " (" +
                     inv_vocab[masked_tokid_at[pos]] + ")" + ENDC)
        if check_mask:
            assert um == (1 if is_masked else 0)
        pieces.append(token + " ")

    utils.log(tokenization.printable_text("".join(pieces)))
|
@@ -1,229 +0,0 @@
|
|
1
|
-
# coding=utf-8
|
2
|
-
# Copyright 2020 The Google Research Authors.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
|
16
|
-
"""Helper functions for pre-training. These mainly deal with the gathering and
|
17
|
-
scattering needed so the generator only makes predictions for the small number
|
18
|
-
of masked tokens.
|
19
|
-
"""
|
20
|
-
|
21
|
-
from __future__ import absolute_import
|
22
|
-
from __future__ import division
|
23
|
-
from __future__ import print_function
|
24
|
-
|
25
|
-
import tensorflow as tf
|
26
|
-
|
27
|
-
import configure_pretraining
|
28
|
-
from model import modeling
|
29
|
-
from model import tokenization
|
30
|
-
from pretrain import pretrain_data
|
31
|
-
|
32
|
-
|
33
|
-
def gather_positions(sequence, positions):
    """Gather the values found at the given positions of each batch row.

    Args:
      sequence: A [batch_size, seq_length] or
        [batch_size, seq_length, depth] tensor of values
      positions: A [batch_size, n_positions] tensor of indices

    Returns: A [batch_size, n_positions] or
      [batch_size, n_positions, depth] tensor of the values at the indices
    """
    dims = modeling.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, length, depth = dims
    else:
        # Add a trailing depth axis of 1 so both ranks share one code path.
        batch, length = dims
        depth = 1
        sequence = tf.expand_dims(sequence, -1)

    # Offset each row's positions so they index into the flattened
    # [batch * length, depth] view of the sequence.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_index = tf.reshape(positions + row_offsets, [-1])
    flat_values = tf.reshape(sequence, [batch * length, depth])
    picked = tf.gather(flat_values, flat_index)

    out_shape = [batch, -1, depth] if has_depth else [batch, -1]
    return tf.reshape(picked, out_shape)
|
60
|
-
|
61
|
-
|
62
|
-
def scatter_update(sequence, updates, positions):
    """Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
        of dtype tf.float32 or tf.int32
      updates: A tensor holding one value (or one depth-sized vector) per
        entry of `positions`; it is reshaped to [-1, depth] before scattering
      positions: A [batch_size, n_positions] tensor

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of "sequence" with elements at
      "positions" replaced by the values at "updates." Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
    """
    dims = modeling.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, length, depth = dims
    else:
        # Add a trailing depth axis of 1 so both ranks share one code path.
        batch, length = dims
        depth = 1
        sequence = tf.expand_dims(sequence, -1)
    n_positions = modeling.get_shape_list(positions)[1]

    # Scatter the updates into a flattened [batch * length, depth] buffer.
    # Duplicate positions are summed by scatter_nd; that is undone below.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_index = tf.reshape(positions + row_offsets, [-1, 1])
    flat_values = tf.reshape(updates, [-1, depth])
    scattered = tf.scatter_nd(flat_index, flat_values, [batch * length, depth])
    scattered = tf.reshape(scattered, [batch, length, depth])

    # Count how many updates landed on each position.
    hit_counts = tf.scatter_nd(
        flat_index, tf.ones([batch * n_positions], tf.int32), [batch * length])
    hit_counts = tf.reshape(hit_counts, [batch, length])
    # Zero the count for the first token of each row so updates to index 0
    # are ignored.
    skip_first = tf.concat([tf.zeros((batch, 1), tf.int32),
                            tf.ones((batch, length - 1), tf.int32)], -1)
    hit_counts *= skip_first
    hit_counts_3d = tf.expand_dims(hit_counts, -1)

    # Account for duplicate positions: divide the summed updates by the
    # number of hits (at least 1).
    if sequence.dtype == tf.float32:
        hit_counts_3d = tf.cast(hit_counts_3d, tf.float32)
        scattered /= tf.maximum(1.0, hit_counts_3d)
    else:
        assert sequence.dtype == tf.int32
        scattered = tf.math.floordiv(scattered, tf.maximum(1, hit_counts_3d))

    # Collapse the counts into 0/1 masks.
    updates_mask = tf.minimum(hit_counts, 1)
    mask_3d = tf.minimum(hit_counts_3d, 1)

    # Keep the original value where the mask is 0, the update where it is 1.
    updated_sequence = (((1 - mask_3d) * sequence) + (mask_3d * scattered))
    if not has_depth:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
|
116
|
-
|
117
|
-
|
118
|
-
# Cache of already-loaded vocabularies, keyed by `config.vocab_file`.
VOCAB_MAPPING = {}


def get_vocab(config: configure_pretraining.PretrainingConfig):
    """Return the vocab for `config.vocab_file`, loading it at most once.

    The first call for a given `config.vocab_file` builds a lower-casing
    `tokenization.FullTokenizer` and stores its `vocab` in VOCAB_MAPPING;
    later calls return the stored object.
    """
    vocab_file = config.vocab_file
    try:
        return VOCAB_MAPPING[vocab_file]
    except KeyError:
        vocab = tokenization.FullTokenizer(
            vocab_file, do_lower_case=True).vocab
        VOCAB_MAPPING[vocab_file] = vocab
        return vocab
|
128
|
-
|
129
|
-
|
130
|
-
def get_candidates_mask(config: configure_pretraining.PretrainingConfig,
                        inputs: pretrain_data.Inputs,
                        disallow_from_mask=None):
    """Return a boolean mask of the input positions that may be masked out.

    A position is a candidate when its token is not [SEP], [CLS] or [MASK],
    its `input_mask` entry is non-zero, and (if `disallow_from_mask` is
    given) it is not flagged there.
    """
    vocab = get_vocab(config)
    special_ids = [vocab[token] for token in ("[SEP]", "[CLS]", "[MASK]")]

    token_ids = inputs.input_ids
    allowed = tf.ones_like(token_ids, tf.bool)
    for special_id in special_ids:
        allowed &= tf.not_equal(token_ids, special_id)
    allowed &= tf.cast(inputs.input_mask, tf.bool)

    if disallow_from_mask is not None:
        allowed &= ~disallow_from_mask
    return allowed
|
143
|
-
|
144
|
-
|
145
|
-
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
    """Implementation of dynamic masking.

    The optional arguments aren't needed for BERT/ELECTRA and are from early
    experiments in "strategically" masking out tokens instead of uniformly at
    random.

    Args:
      config: configure_pretraining.PretrainingConfig
      inputs: pretrain_data.Inputs containing input input_ids/input_mask
      mask_prob: fraction of each example's tokens to mask (it is multiplied
        directly with the token count)
      proposal_distribution: for non-uniform masking can be a [B, L] tensor
        of scores for masking each position.
      disallow_from_mask: a boolean tensor of [B, L] of positions that should
        not be masked out
      already_masked: a [B, N] tensor of already masked-out tokens for
        multiple rounds of masking.
        NOTE(review): it is used arithmetically as `1 - already_masked` on a
        float tensor, so presumably a 0/1 float tensor rather than tf.bool;
        confirm against callers.

    Returns: a pretrain_data.Inputs with masking added
    """
    # Get the batch size, sequence length, and max masked-out tokens
    N = config.max_predictions_per_seq
    B, L = modeling.get_shape_list(inputs.input_ids)

    # Find indices where masking out a token is allowed
    vocab = get_vocab(config)
    candidates_mask = get_candidates_mask(config, inputs, disallow_from_mask)

    # Set the number of tokens to mask out per example: round(tokens *
    # mask_prob), clamped to the range [1, N].
    num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
    num_to_predict = tf.maximum(1, tf.minimum(
        N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
    masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
    if already_masked is not None:
        masked_lm_weights *= (1 - already_masked)

    # Get a probability of masking each position in the sequence
    candidate_mask_float = tf.cast(candidates_mask, tf.float32)
    sample_prob = (proposal_distribution * candidate_mask_float)
    sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

    # Sample the positions to mask out. `tf.math.log` is used instead of the
    # `tf.log` alias because the alias is absent from the TF2 namespace.
    sample_prob = tf.stop_gradient(sample_prob)
    sample_logits = tf.math.log(sample_prob)
    masked_lm_positions = tf.random.categorical(
        sample_logits, N, dtype=tf.int32)
    # Unused prediction slots (weight 0) get position 0.
    masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

    # Get the ids of the masked-out tokens
    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
    masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                                 flat_positions)
    masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
    masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

    # Update the input ids. Each selected position is kept for [MASK]
    # replacement with probability 0.85; otherwise its position is zeroed,
    # and scatter_update ignores updates to index 0.
    replace_with_mask_positions = masked_lm_positions * tf.cast(
        tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
    inputs_ids, _ = scatter_update(
        inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
        replace_with_mask_positions)

    return pretrain_data.get_updated_inputs(
        inputs,
        input_ids=tf.stop_gradient(inputs_ids),
        masked_lm_positions=masked_lm_positions,
        masked_lm_ids=masked_lm_ids,
        masked_lm_weights=masked_lm_weights
    )
|
214
|
-
|
215
|
-
|
216
|
-
def unmask(inputs: pretrain_data.Inputs):
    """Return `inputs` with `masked_lm_ids` written back into `input_ids`.

    The ids are scattered to `masked_lm_positions` via `scatter_update`; all
    other fields of `inputs` are carried over unchanged.
    """
    restored_ids, _ = scatter_update(
        sequence=inputs.input_ids,
        updates=inputs.masked_lm_ids,
        positions=inputs.masked_lm_positions)
    return pretrain_data.get_updated_inputs(inputs, input_ids=restored_ids)
|
220
|
-
|
221
|
-
|
222
|
-
def sample_from_softmax(logits, disallow=None):
    """Draw a one-hot sample from softmax(logits) by adding Gumbel noise.

    Args:
      logits: tensor of unnormalized scores; sampling is over the last axis.
      disallow: optional tensor multiplied by 1000 and subtracted from
        `logits`, so entries where it is 1 are pushed far down and become
        very unlikely to be chosen.

    Returns:
      A one-hot tensor whose last axis has size `logits.shape[-1]`, marking
      the sampled index.
    """
    if disallow is not None:
        logits -= 1000.0 * disallow
    uniform_noise = tf.random.uniform(
        modeling.get_shape_list(logits), minval=0, maxval=1)
    # The 1e-9 terms keep both logs away from log(0). `tf.math.log` is used
    # instead of the `tf.log` alias, which is absent from the TF2 namespace.
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    return tf.one_hot(tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1,
                                output_type=tf.int32), logits.shape[-1])
|