secryst-trainer 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/README.adoc +103 -0
- data/lib/secryst-trainer.rb +8 -0
- data/lib/secryst.rb +11 -0
- data/lib/secryst/clip_grad_norm.rb +25 -0
- data/lib/secryst/multi_head_attention_forward.rb +288 -0
- data/lib/secryst/multihead_attention.rb +156 -0
- data/lib/secryst/trainer.rb +235 -0
- data/lib/secryst/transformer.rb +382 -0
- data/lib/secryst/translator.rb +51 -0
- data/lib/secryst/version.rb +3 -0
- data/lib/secryst/vocab.rb +88 -0
- metadata +138 -0
@@ -0,0 +1,51 @@
|
|
1
|
+
module Secryst
  # Loads a trained seq2seq model together with its input/target vocabularies
  # and greedily decodes a translation for a phrase, one token at a time.
  class Translator
    # @param model [String] model type; only 'transformer' is accepted
    # @param vocabs_dir [String] directory containing input_vocab.json and
    #   target_vocab.json
    # @param hyperparameters [Hash] options forwarded to Secryst::Transformer.new;
    #   the vocabulary sizes are merged in from the loaded vocabs
    # @param model_file [String] path passed to Torch.load to obtain the state dict
    # @raise [ArgumentError] when +model+ is anything other than 'transformer'
    def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
      @device = "cpu"
      @vocabs_dir = vocabs_dir

      # Vocabs are loaded first because the model needs their sizes.
      load_vocabs

      if model == 'transformer'
        @model = Secryst::Transformer.new(hyperparameters.merge({
          input_vocab_size: @input_vocab.length,
          target_vocab_size: @target_vocab.length,
        }))
      else
        raise ArgumentError, 'Only transformer model is currently supported'
      end

      @model.load_state_dict(Torch.load(model_file))
      @model.eval
    end

    # Translates +phrase+ character by character using greedy decoding.
    #
    # The phrase is split into characters and wrapped in '<sos>'/'<eos>'. The
    # target sequence starts as a single '<sos>' and grows by one predicted
    # token per step until '<eos>' is predicted or +max_seq_length+ steps ran.
    # The result is printed to stdout (as before) and also returned so callers
    # can use it programmatically.
    #
    # @param phrase [String] text to translate
    # @param max_seq_length [Integer] maximum number of decoding steps
    # @return [String] the decoded tokens joined together, without '<sos>'
    def translate(phrase, max_seq_length: 100)
      tokens = ['<sos>'] + phrase.chars + ['<eos>']
      # Built as a single row and then transposed, i.e. one column of indices.
      input = Torch.tensor([tokens.map { |token| @input_vocab.stoi[token] }]).t
      output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
      # 1 is the index of '<pad>' under Vocab's default specials ordering,
      # which is what load_vocabs uses.
      src_key_padding_mask = input.t.eq(1)

      max_seq_length.times do |step|
        tgt_key_padding_mask = output.t.eq(1)
        # Square boolean mask of size (step + 1), true strictly above the
        # diagonal; presumably blocks attention to future positions - confirm
        # against Secryst::Transformer.
        tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
        opts = {
          tgt_mask: tgt_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        # One argmax per element yielded by the model output; only the entry
        # at the current step is consumed below.
        prediction = @model.call(input, output, opts).map { |scores| scores.argmax.item }
        break if @target_vocab.itos[prediction[step]] == '<eos>'

        output = Torch.cat([output, Torch.tensor([[prediction[step]]])])
      end

      # Drop the leading '<sos>' and map indices back to tokens.
      translation = output[1..-1].map { |index| @target_vocab.itos[index.item] }.join('')
      puts translation
      translation
    end

    private

    # Reads both vocabularies from JSON files under @vocabs_dir and passes the
    # parsed content to Vocab.new as the frequency counter.
    def load_vocabs
      @input_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/input_vocab.json")))
      @target_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/target_vocab.json")))
    end
  end
end
|
@@ -0,0 +1,88 @@
|
|
1
|
+
module Secryst
  # Bidirectional token <-> index mapping built from a token frequency table.
  #
  # +itos+ is the ordered list of tokens (index -> token), +stoi+ the reverse
  # hash (token -> index) and +freqs+ the frequency table that was passed in.
  class Vocab
    UNK = "<unk>"
    attr_reader :stoi, :itos, :freqs

    # Builds the vocabulary.
    #
    # Special tokens are excluded from frequency ordering. The remaining tokens
    # are sorted by descending frequency, then alphabetically, and appended
    # until one falls below +min_freq+ or +max_size+ is reached.
    #
    # @param counter [Hash] token => frequency; kept as-is in +freqs+
    # @param max_size [Integer, nil] maximum number of non-special tokens, or
    #   nil for no limit (only extended by the specials when they come first)
    # @param min_freq [Integer] minimum frequency for inclusion (values below 1
    #   are treated as 1)
    # @param specials [Array<String>] special tokens; never modified
    # @param vectors [Object, nil] pretrained vectors; not implemented
    # @param unk_init [Object, nil] must be nil when +vectors+ is nil
    # @param vectors_cache [Object, nil] must be nil when +vectors+ is nil
    # @param specials_first [Boolean] put specials before (true) or after
    #   (false) the regular tokens
    # @raise [RuntimeError] when +vectors+ is given, or when +unk_init+ /
    #   +vectors_cache+ is given without +vectors+
    def initialize(
      counter, max_size: nil, min_freq: 1, specials: ["<unk>", "<pad>", "<sos>", "<eos>"],
      vectors: nil, unk_init: nil, vectors_cache: nil, specials_first: true
    )
      @freqs = counter
      counter = counter.dup
      min_freq = [min_freq, 1].max

      @itos = []
      @unk_index = nil

      if specials_first
        # Copy so that appending words below never mutates (or fails on a
        # frozen) array owned by the caller.
        @itos = specials.dup
        # only extend max size if specials are prepended
        max_size += specials.size if max_size
      end

      # frequencies of special tokens are not counted when building vocabulary
      # in frequency order
      specials.each { |tok| counter.delete(tok) }

      # sort by frequency, then alphabetically
      words_and_frequencies = counter.sort_by { |word, freq| [-freq, word] }

      words_and_frequencies.each do |word, freq|
        break if freq < min_freq || @itos.length == max_size

        @itos << word
      end

      if specials.include?(UNK) # hard-coded for now
        unk_position = specials.index(UNK) # position in list
        # account for ordering of specials: when they are appended, the offset
        # is the number of regular tokens collected so far
        @unk_index = specials_first ? unk_position : @itos.length + unk_position
        # unknown tokens looked up through stoi resolve to the <unk> index
        @stoi = Hash.new(@unk_index)
      else
        @stoi = {}
      end

      @itos.concat(specials) unless specials_first

      # stoi is simply a reverse dict for itos
      @itos.each_with_index { |tok, index| @stoi[tok] = index }

      @vectors = nil
      if !vectors.nil?
        # self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        raise "Not implemented yet"
      else
        raise "Failed assertion" unless unk_init.nil?
        raise "Failed assertion" unless vectors_cache.nil?
      end
    end

    # Looks up the index of +token+, falling back to the index of "<unk>".
    #
    # The fallback is resolved lazily, so a vocabulary built without "<unk>"
    # can still look up tokens it does contain.
    #
    # @param token [String]
    # @return [Integer]
    # @raise [KeyError] when +token+ is unknown and the vocab has no "<unk>"
    def [](token)
      @stoi.fetch(token) { @stoi.fetch(UNK) }
    end

    # @return [Integer] number of tokens, specials included
    def length
      @itos.length
    end
    alias_method :size, :length

    # Counts tokens across an iterator of token lists and builds a Vocab with
    # default options. Prints a progress line every 10000 processed lists.
    #
    # @param iterator [#each] yields enumerables of tokens
    # @return [Vocab]
    def self.build_vocab_from_iterator(iterator)
      counter = Hash.new(0)
      processed = 0
      iterator.each do |tokens|
        tokens.each { |token| counter[token] += 1 }
        processed += 1
        puts "Processed #{processed}" if processed % 10000 == 0
      end
      Vocab.new(counter)
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,138 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: secryst-trainer
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- project_contibutors
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2020-10-05 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: torch-rb
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0.4'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0.4'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: numo
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0.1'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0.1'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: numo-linalg
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0.1'
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0.1'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: secryst
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - '='
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: 0.1.0
|
62
|
+
type: :runtime
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - '='
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: 0.1.0
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: rake
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '0'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '0'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: rspec
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
description:
|
98
|
+
email:
|
99
|
+
executables: []
|
100
|
+
extensions: []
|
101
|
+
extra_rdoc_files: []
|
102
|
+
files:
|
103
|
+
- README.adoc
|
104
|
+
- lib/secryst-trainer.rb
|
105
|
+
- lib/secryst.rb
|
106
|
+
- lib/secryst/clip_grad_norm.rb
|
107
|
+
- lib/secryst/multi_head_attention_forward.rb
|
108
|
+
- lib/secryst/multihead_attention.rb
|
109
|
+
- lib/secryst/trainer.rb
|
110
|
+
- lib/secryst/transformer.rb
|
111
|
+
- lib/secryst/translator.rb
|
112
|
+
- lib/secryst/version.rb
|
113
|
+
- lib/secryst/vocab.rb
|
114
|
+
homepage: https://github.com/secryst/secryst
|
115
|
+
licenses:
|
116
|
+
- BSD-2-Clause
|
117
|
+
metadata: {}
|
118
|
+
post_install_message:
|
119
|
+
rdoc_options: []
|
120
|
+
require_paths:
|
121
|
+
- lib
|
122
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
123
|
+
requirements:
|
124
|
+
- - ">="
|
125
|
+
- !ruby/object:Gem::Version
|
126
|
+
version: '2.7'
|
127
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
128
|
+
requirements:
|
129
|
+
- - ">="
|
130
|
+
- !ruby/object:Gem::Version
|
131
|
+
version: '0'
|
132
|
+
requirements: []
|
133
|
+
rubygems_version: 3.1.2
|
134
|
+
signing_key:
|
135
|
+
specification_version: 4
|
136
|
+
summary: A seq2seq transformer suited for transliteration. Written in Ruby. Includes
|
137
|
+
packages for training models
|
138
|
+
test_files: []
|