SinaTools 0.1.40__py2.py3-none-any.whl → 1.0.1__py2.py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (64)
  1. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/METADATA +1 -1
  2. SinaTools-1.0.1.dist-info/RECORD +73 -0
  3. sinatools/VERSION +1 -1
  4. sinatools/ner/__init__.py +5 -7
  5. sinatools/ner/trainers/BertNestedTrainer.py +203 -203
  6. sinatools/ner/trainers/BertTrainer.py +163 -163
  7. sinatools/ner/trainers/__init__.py +2 -2
  8. SinaTools-0.1.40.dist-info/RECORD +0 -123
  9. sinatools/arabert/arabert/__init__.py +0 -14
  10. sinatools/arabert/arabert/create_classification_data.py +0 -260
  11. sinatools/arabert/arabert/create_pretraining_data.py +0 -534
  12. sinatools/arabert/arabert/extract_features.py +0 -444
  13. sinatools/arabert/arabert/lamb_optimizer.py +0 -158
  14. sinatools/arabert/arabert/modeling.py +0 -1027
  15. sinatools/arabert/arabert/optimization.py +0 -202
  16. sinatools/arabert/arabert/run_classifier.py +0 -1078
  17. sinatools/arabert/arabert/run_pretraining.py +0 -593
  18. sinatools/arabert/arabert/run_squad.py +0 -1440
  19. sinatools/arabert/arabert/tokenization.py +0 -414
  20. sinatools/arabert/araelectra/__init__.py +0 -1
  21. sinatools/arabert/araelectra/build_openwebtext_pretraining_dataset.py +0 -103
  22. sinatools/arabert/araelectra/build_pretraining_dataset.py +0 -230
  23. sinatools/arabert/araelectra/build_pretraining_dataset_single_file.py +0 -90
  24. sinatools/arabert/araelectra/configure_finetuning.py +0 -172
  25. sinatools/arabert/araelectra/configure_pretraining.py +0 -143
  26. sinatools/arabert/araelectra/finetune/__init__.py +0 -14
  27. sinatools/arabert/araelectra/finetune/feature_spec.py +0 -56
  28. sinatools/arabert/araelectra/finetune/preprocessing.py +0 -173
  29. sinatools/arabert/araelectra/finetune/scorer.py +0 -54
  30. sinatools/arabert/araelectra/finetune/task.py +0 -74
  31. sinatools/arabert/araelectra/finetune/task_builder.py +0 -70
  32. sinatools/arabert/araelectra/flops_computation.py +0 -215
  33. sinatools/arabert/araelectra/model/__init__.py +0 -14
  34. sinatools/arabert/araelectra/model/modeling.py +0 -1029
  35. sinatools/arabert/araelectra/model/optimization.py +0 -193
  36. sinatools/arabert/araelectra/model/tokenization.py +0 -355
  37. sinatools/arabert/araelectra/pretrain/__init__.py +0 -14
  38. sinatools/arabert/araelectra/pretrain/pretrain_data.py +0 -160
  39. sinatools/arabert/araelectra/pretrain/pretrain_helpers.py +0 -229
  40. sinatools/arabert/araelectra/run_finetuning.py +0 -323
  41. sinatools/arabert/araelectra/run_pretraining.py +0 -469
  42. sinatools/arabert/araelectra/util/__init__.py +0 -14
  43. sinatools/arabert/araelectra/util/training_utils.py +0 -112
  44. sinatools/arabert/araelectra/util/utils.py +0 -109
  45. sinatools/arabert/aragpt2/__init__.py +0 -2
  46. sinatools/arabert/aragpt2/create_pretraining_data.py +0 -95
  47. sinatools/arabert/aragpt2/gpt2/__init__.py +0 -2
  48. sinatools/arabert/aragpt2/gpt2/lamb_optimizer.py +0 -158
  49. sinatools/arabert/aragpt2/gpt2/optimization.py +0 -225
  50. sinatools/arabert/aragpt2/gpt2/run_pretraining.py +0 -397
  51. sinatools/arabert/aragpt2/grover/__init__.py +0 -0
  52. sinatools/arabert/aragpt2/grover/dataloader.py +0 -161
  53. sinatools/arabert/aragpt2/grover/modeling.py +0 -803
  54. sinatools/arabert/aragpt2/grover/modeling_gpt2.py +0 -1196
  55. sinatools/arabert/aragpt2/grover/optimization_adafactor.py +0 -234
  56. sinatools/arabert/aragpt2/grover/train_tpu.py +0 -187
  57. sinatools/arabert/aragpt2/grover/utils.py +0 -234
  58. sinatools/arabert/aragpt2/train_bpe_tokenizer.py +0 -59
  59. {SinaTools-0.1.40.data → SinaTools-1.0.1.data}/data/sinatools/environment.yml +0 -0
  60. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/AUTHORS.rst +0 -0
  61. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/LICENSE +0 -0
  62. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/WHEEL +0 -0
  63. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/entry_points.txt +0 -0
  64. {SinaTools-0.1.40.dist-info → SinaTools-1.0.1.dist-info}/top_level.txt +0 -0
@@ -1,444 +0,0 @@
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Extract pre-computed feature vectors from BERT."""
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import codecs
- import collections
- import json
- import re
-
- import modeling
- import tokenization
- import tensorflow as tf
-
- flags = tf.flags
-
- FLAGS = flags.FLAGS
-
- flags.DEFINE_string("input_file", None, "")
-
- flags.DEFINE_string("output_file", None, "")
-
- flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
-
- flags.DEFINE_string(
-     "bert_config_file",
-     None,
-     "The config json file corresponding to the pre-trained BERT model. "
-     "This specifies the model architecture.",
- )
-
- flags.DEFINE_integer(
-     "max_seq_length",
-     128,
-     "The maximum total input sequence length after WordPiece tokenization. "
-     "Sequences longer than this will be truncated, and sequences shorter "
-     "than this will be padded.",
- )
-
- flags.DEFINE_string(
-     "init_checkpoint",
-     None,
-     "Initial checkpoint (usually from a pre-trained BERT model).",
- )
-
- flags.DEFINE_string(
-     "vocab_file", None, "The vocabulary file that the BERT model was trained on."
- )
-
- flags.DEFINE_bool(
-     "do_lower_case",
-     True,
-     "Whether to lower case the input text. Should be True for uncased "
-     "models and False for cased models.",
- )
-
- flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")
-
- flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
-
- flags.DEFINE_string("master", None, "If using a TPU, the address of the master.")
-
- flags.DEFINE_integer(
-     "num_tpu_cores",
-     8,
-     "Only used if `use_tpu` is True. Total number of TPU cores to use.",
- )
-
- flags.DEFINE_bool(
-     "use_one_hot_embeddings",
-     False,
-     "If True, tf.one_hot will be used for embedding lookups, otherwise "
-     "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
-     "since it is much faster.",
- )
-
-
- class InputExample(object):
-     def __init__(self, unique_id, text_a, text_b):
-         self.unique_id = unique_id
-         self.text_a = text_a
-         self.text_b = text_b
-
-
- class InputFeatures(object):
-     """A single set of features of data."""
-
-     def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
-         self.unique_id = unique_id
-         self.tokens = tokens
-         self.input_ids = input_ids
-         self.input_mask = input_mask
-         self.input_type_ids = input_type_ids
-
-
- def input_fn_builder(features, seq_length):
-     """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-     all_unique_ids = []
-     all_input_ids = []
-     all_input_mask = []
-     all_input_type_ids = []
-
-     for feature in features:
-         all_unique_ids.append(feature.unique_id)
-         all_input_ids.append(feature.input_ids)
-         all_input_mask.append(feature.input_mask)
-         all_input_type_ids.append(feature.input_type_ids)
-
-     def input_fn(params):
-         """The actual input function."""
-         batch_size = params["batch_size"]
-
-         num_examples = len(features)
-
-         # This is for demo purposes and does NOT scale to large data sets. We do
-         # not use Dataset.from_generator() because that uses tf.py_func which is
-         # not TPU compatible. The right way to load data is with TFRecordReader.
-         d = tf.data.Dataset.from_tensor_slices(
-             {
-                 "unique_ids": tf.constant(
-                     all_unique_ids, shape=[num_examples], dtype=tf.int32
-                 ),
-                 "input_ids": tf.constant(
-                     all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32
-                 ),
-                 "input_mask": tf.constant(
-                     all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32
-                 ),
-                 "input_type_ids": tf.constant(
-                     all_input_type_ids, shape=[num_examples, seq_length], dtype=tf.int32
-                 ),
-             }
-         )
-
-         d = d.batch(batch_size=batch_size, drop_remainder=False)
-         return d
-
-     return input_fn
-
-
- def model_fn_builder(
-     bert_config, init_checkpoint, layer_indexes, use_tpu, use_one_hot_embeddings
- ):
-     """Returns `model_fn` closure for TPUEstimator."""
-
-     def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-         """The `model_fn` for TPUEstimator."""
-
-         unique_ids = features["unique_ids"]
-         input_ids = features["input_ids"]
-         input_mask = features["input_mask"]
-         input_type_ids = features["input_type_ids"]
-
-         model = modeling.BertModel(
-             config=bert_config,
-             is_training=False,
-             input_ids=input_ids,
-             input_mask=input_mask,
-             token_type_ids=input_type_ids,
-             use_one_hot_embeddings=use_one_hot_embeddings,
-         )
-
-         if mode != tf.estimator.ModeKeys.PREDICT:
-             raise ValueError("Only PREDICT modes are supported: %s" % (mode))
-
-         tvars = tf.trainable_variables()
-         scaffold_fn = None
-         (
-             assignment_map,
-             initialized_variable_names,
-         ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
-         if use_tpu:
-
-             def tpu_scaffold():
-                 tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-                 return tf.train.Scaffold()
-
-             scaffold_fn = tpu_scaffold
-         else:
-             tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-         tf.logging.info("**** Trainable Variables ****")
-         for var in tvars:
-             init_string = ""
-             if var.name in initialized_variable_names:
-                 init_string = ", *INIT_FROM_CKPT*"
-             tf.logging.info(
-                 " name = %s, shape = %s%s", var.name, var.shape, init_string
-             )
-
-         all_layers = model.get_all_encoder_layers()
-
-         predictions = {
-             "unique_id": unique_ids,
-         }
-
-         for (i, layer_index) in enumerate(layer_indexes):
-             predictions["layer_output_%d" % i] = all_layers[layer_index]
-
-         output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-             mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
-         )
-         return output_spec
-
-     return model_fn
-
-
- def convert_examples_to_features(examples, seq_length, tokenizer):
-     """Loads a data file into a list of `InputBatch`s."""
-
-     features = []
-     for (ex_index, example) in enumerate(examples):
-         tokens_a = tokenizer.tokenize(example.text_a)
-
-         tokens_b = None
-         if example.text_b:
-             tokens_b = tokenizer.tokenize(example.text_b)
-
-         if tokens_b:
-             # Modifies `tokens_a` and `tokens_b` in place so that the total
-             # length is less than the specified length.
-             # Account for [CLS], [SEP], [SEP] with "- 3"
-             _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
-         else:
-             # Account for [CLS] and [SEP] with "- 2"
-             if len(tokens_a) > seq_length - 2:
-                 tokens_a = tokens_a[0 : (seq_length - 2)]
-
-         # The convention in BERT is:
-         # (a) For sequence pairs:
-         # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
-         # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
-         # (b) For single sequences:
-         # tokens: [CLS] the dog is hairy . [SEP]
-         # type_ids: 0 0 0 0 0 0 0
-         #
-         # Where "type_ids" are used to indicate whether this is the first
-         # sequence or the second sequence. The embedding vectors for `type=0` and
-         # `type=1` were learned during pre-training and are added to the wordpiece
-         # embedding vector (and position vector). This is not *strictly* necessary
-         # since the [SEP] token unambiguously separates the sequences, but it makes
-         # it easier for the model to learn the concept of sequences.
-         #
-         # For classification tasks, the first vector (corresponding to [CLS]) is
-         # used as as the "sentence vector". Note that this only makes sense because
-         # the entire model is fine-tuned.
-         tokens = []
-         input_type_ids = []
-         tokens.append("[CLS]")
-         input_type_ids.append(0)
-         for token in tokens_a:
-             tokens.append(token)
-             input_type_ids.append(0)
-         tokens.append("[SEP]")
-         input_type_ids.append(0)
-
-         if tokens_b:
-             for token in tokens_b:
-                 tokens.append(token)
-                 input_type_ids.append(1)
-             tokens.append("[SEP]")
-             input_type_ids.append(1)
-
-         input_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-         # The mask has 1 for real tokens and 0 for padding tokens. Only real
-         # tokens are attended to.
-         input_mask = [1] * len(input_ids)
-
-         # Zero-pad up to the sequence length.
-         while len(input_ids) < seq_length:
-             input_ids.append(0)
-             input_mask.append(0)
-             input_type_ids.append(0)
-
-         assert len(input_ids) == seq_length
-         assert len(input_mask) == seq_length
-         assert len(input_type_ids) == seq_length
-
-         if ex_index < 5:
-             tf.logging.info("*** Example ***")
-             tf.logging.info("unique_id: %s" % (example.unique_id))
-             tf.logging.info(
-                 "tokens: %s"
-                 % " ".join([tokenization.printable_text(x) for x in tokens])
-             )
-             tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-             tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-             tf.logging.info(
-                 "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])
-             )
-
-         features.append(
-             InputFeatures(
-                 unique_id=example.unique_id,
-                 tokens=tokens,
-                 input_ids=input_ids,
-                 input_mask=input_mask,
-                 input_type_ids=input_type_ids,
-             )
-         )
-     return features
-
-
- def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-     """Truncates a sequence pair in place to the maximum length."""
-
-     # This is a simple heuristic which will always truncate the longer sequence
-     # one token at a time. This makes more sense than truncating an equal percent
-     # of tokens from each, since if one sequence is very short then each token
-     # that's truncated likely contains more information than a longer sequence.
-     while True:
-         total_length = len(tokens_a) + len(tokens_b)
-         if total_length <= max_length:
-             break
-         if len(tokens_a) > len(tokens_b):
-             tokens_a.pop()
-         else:
-             tokens_b.pop()
-
-
- def read_examples(input_file):
-     """Read a list of `InputExample`s from an input file."""
-     examples = []
-     unique_id = 0
-     with tf.gfile.GFile(input_file, "r") as reader:
-         while True:
-             line = tokenization.convert_to_unicode(reader.readline())
-             if not line:
-                 break
-             line = line.strip()
-             text_a = None
-             text_b = None
-             m = re.match(r"^(.*) \|\|\| (.*)$", line)
-             if m is None:
-                 text_a = line
-             else:
-                 text_a = m.group(1)
-                 text_b = m.group(2)
-             examples.append(
-                 InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)
-             )
-             unique_id += 1
-     return examples
-
-
- def main(_):
-     tf.logging.set_verbosity(tf.logging.INFO)
-     logger = tf.get_logger()
-     logger.propagate = False
-
-     layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
-
-     bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-
-     tokenizer = tokenization.FullTokenizer(
-         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
-     )
-
-     is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
-     run_config = tf.contrib.tpu.RunConfig(
-         master=FLAGS.master,
-         tpu_config=tf.contrib.tpu.TPUConfig(
-             num_shards=FLAGS.num_tpu_cores, per_host_input_for_training=is_per_host
-         ),
-     )
-
-     examples = read_examples(FLAGS.input_file)
-
-     features = convert_examples_to_features(
-         examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer
-     )
-
-     unique_id_to_feature = {}
-     for feature in features:
-         unique_id_to_feature[feature.unique_id] = feature
-
-     model_fn = model_fn_builder(
-         bert_config=bert_config,
-         init_checkpoint=FLAGS.init_checkpoint,
-         layer_indexes=layer_indexes,
-         use_tpu=FLAGS.use_tpu,
-         use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
-     )
-
-     # If TPU is not available, this will fall back to normal Estimator on CPU
-     # or GPU.
-     estimator = tf.contrib.tpu.TPUEstimator(
-         use_tpu=FLAGS.use_tpu,
-         model_fn=model_fn,
-         config=run_config,
-         predict_batch_size=FLAGS.batch_size,
-     )
-
-     input_fn = input_fn_builder(features=features, seq_length=FLAGS.max_seq_length)
-
-     with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, "w")) as writer:
-         for result in estimator.predict(input_fn, yield_single_examples=True):
-             unique_id = int(result["unique_id"])
-             feature = unique_id_to_feature[unique_id]
-             output_json = collections.OrderedDict()
-             output_json["linex_index"] = unique_id
-             all_features = []
-             for (i, token) in enumerate(feature.tokens):
-                 all_layers = []
-                 for (j, layer_index) in enumerate(layer_indexes):
-                     layer_output = result["layer_output_%d" % j]
-                     layers = collections.OrderedDict()
-                     layers["index"] = layer_index
-                     layers["values"] = [
-                         round(float(x), 6) for x in layer_output[i : (i + 1)].flat
-                     ]
-                     all_layers.append(layers)
-                 features = collections.OrderedDict()
-                 features["token"] = token
-                 features["layers"] = all_layers
-                 all_features.append(features)
-             output_json["features"] = all_features
-             writer.write(json.dumps(output_json) + "\n")
-
-
- if __name__ == "__main__":
-     flags.mark_flag_as_required("input_file")
-     flags.mark_flag_as_required("vocab_file")
-     flags.mark_flag_as_required("bert_config_file")
-     flags.mark_flag_as_required("init_checkpoint")
-     flags.mark_flag_as_required("output_file")
-     tf.app.run()
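
Note on the removed extract_features.py hunk above: read_examples() accepts one example per line, with an optional second sentence separated by " ||| " (the regex it matches). Below is a minimal standalone sketch of that parsing, for illustration only; it is not part of SinaTools and the helper name is hypothetical.

import re

def split_example(line):
    # Mirrors the " ||| " convention used by read_examples() above (illustrative helper).
    m = re.match(r"^(.*) \|\|\| (.*)$", line.strip())
    if m is None:
        return line.strip(), None           # single sentence: text_b is None
    return m.group(1), m.group(2)            # sentence pair: (text_a, text_b)

print(split_example("is this jacksonville ? ||| no it is not ."))
print(split_example("the dog is hairy ."))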
@@ -1,158 +0,0 @@
- # coding=utf-8
- # Copyright 2019 The Google Research Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # Lint as: python2, python3
- """Functions and classes related to optimization (weight updates)."""
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import re
- import six
- import tensorflow as tf
-
- # pylint: disable=g-direct-tensorflow-import
- from tensorflow.python.ops import array_ops
- from tensorflow.python.ops import linalg_ops
- from tensorflow.python.ops import math_ops
-
- # pylint: enable=g-direct-tensorflow-import
-
-
- class LAMBOptimizer(tf.train.Optimizer):
-     """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
-
-     # A new optimizer that includes correct L2 weight decay, adaptive
-     # element-wise updating, and layer-wise justification. The LAMB optimizer
-     # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
-     # James Demmel, and Cho-Jui Hsieh in a paper titled as Reducing BERT
-     # Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962)
-
-     def __init__(
-         self,
-         learning_rate,
-         weight_decay_rate=0.0,
-         beta_1=0.9,
-         beta_2=0.999,
-         epsilon=1e-6,
-         exclude_from_weight_decay=None,
-         exclude_from_layer_adaptation=None,
-         name="LAMBOptimizer",
-     ):
-         """Constructs a LAMBOptimizer."""
-         super(LAMBOptimizer, self).__init__(False, name)
-
-         self.learning_rate = learning_rate
-         self.weight_decay_rate = weight_decay_rate
-         self.beta_1 = beta_1
-         self.beta_2 = beta_2
-         self.epsilon = epsilon
-         self.exclude_from_weight_decay = exclude_from_weight_decay
-         # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
-         # arg is None.
-         # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
-         if exclude_from_layer_adaptation:
-             self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
-         else:
-             self.exclude_from_layer_adaptation = exclude_from_weight_decay
-
-     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-         """See base class."""
-         assignments = []
-         for (grad, param) in grads_and_vars:
-             if grad is None or param is None:
-                 continue
-
-             param_name = self._get_variable_name(param.name)
-
-             m = tf.get_variable(
-                 name=six.ensure_str(param_name) + "/adam_m",
-                 shape=param.shape.as_list(),
-                 dtype=tf.float32,
-                 trainable=False,
-                 initializer=tf.zeros_initializer(),
-             )
-             v = tf.get_variable(
-                 name=six.ensure_str(param_name) + "/adam_v",
-                 shape=param.shape.as_list(),
-                 dtype=tf.float32,
-                 trainable=False,
-                 initializer=tf.zeros_initializer(),
-             )
-
-             # Standard Adam update.
-             next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
-             next_v = tf.multiply(self.beta_2, v) + tf.multiply(
-                 1.0 - self.beta_2, tf.square(grad)
-             )
-
-             update = next_m / (tf.sqrt(next_v) + self.epsilon)
-
-             # Just adding the square of the weights to the loss function is *not*
-             # the correct way of using L2 regularization/weight decay with Adam,
-             # since that will interact with the m and v parameters in strange ways.
-             #
-             # Instead we want ot decay the weights in a manner that doesn't interact
-             # with the m/v parameters. This is equivalent to adding the square
-             # of the weights to the loss with plain (non-momentum) SGD.
-             if self._do_use_weight_decay(param_name):
-                 update += self.weight_decay_rate * param
-
-             ratio = 1.0
-             if self._do_layer_adaptation(param_name):
-                 w_norm = linalg_ops.norm(param, ord=2)
-                 g_norm = linalg_ops.norm(update, ord=2)
-                 ratio = array_ops.where(
-                     math_ops.greater(w_norm, 0),
-                     array_ops.where(
-                         math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0
-                     ),
-                     1.0,
-                 )
-
-             update_with_lr = ratio * self.learning_rate * update
-
-             next_param = param - update_with_lr
-
-             assignments.extend(
-                 [param.assign(next_param), m.assign(next_m), v.assign(next_v)]
-             )
-         return tf.group(*assignments, name=name)
-
-     def _do_use_weight_decay(self, param_name):
-         """Whether to use L2 weight decay for `param_name`."""
-         if not self.weight_decay_rate:
-             return False
-         if self.exclude_from_weight_decay:
-             for r in self.exclude_from_weight_decay:
-                 if re.search(r, param_name) is not None:
-                     return False
-         return True
-
-     def _do_layer_adaptation(self, param_name):
-         """Whether to do layer-wise learning rate adaptation for `param_name`."""
-         if self.exclude_from_layer_adaptation:
-             for r in self.exclude_from_layer_adaptation:
-                 if re.search(r, param_name) is not None:
-                     return False
-         return True
-
-     def _get_variable_name(self, param_name):
-         """Get the variable name from the tensor name."""
-         m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
-         if m is not None:
-             param_name = m.group(1)
-         return param_name
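
Note on the removed lamb_optimizer.py hunk above: each apply_gradients() step is an Adam-style moment update, followed by optional decoupled weight decay, with the result scaled per variable by the trust ratio ||w|| / ||update|| before the learning rate is applied. Below is a minimal NumPy sketch of one such step; it is illustrative only, not part of SinaTools, and the function name is hypothetical.

import numpy as np

def lamb_step(param, grad, m, v, lr, wd=0.0, b1=0.9, b2=0.999, eps=1e-6):
    # Adam-style first and second moments (no bias correction, matching the code above).
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad ** 2
    update = m / (np.sqrt(v) + eps)
    if wd:
        update = update + wd * param         # decoupled weight decay
    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(update)
    # Layer-wise trust ratio; falls back to 1.0 when either norm is zero.
    ratio = (w_norm / g_norm) if (w_norm > 0 and g_norm > 0) else 1.0
    return param - ratio * lr * update, m, v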