secryst 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.adoc +103 -0
- data/lib/secryst-trainer.rb +8 -0
- data/lib/secryst.rb +11 -0
- data/lib/secryst/clip_grad_norm.rb +25 -0
- data/lib/secryst/multi_head_attention_forward.rb +288 -0
- data/lib/secryst/multihead_attention.rb +156 -0
- data/lib/secryst/trainer.rb +235 -0
- data/lib/secryst/transformer.rb +382 -0
- data/lib/secryst/translator.rb +51 -0
- data/lib/secryst/version.rb +3 -0
- data/lib/secryst/vocab.rb +88 -0
- metadata +95 -0
@@ -0,0 +1,51 @@
|
|
1
|
+
module Secryst
  # Inference wrapper around a trained Secryst model.
  #
  # Loads the input/target vocabularies from +vocabs_dir+, builds the model
  # described by +hyperparameters+, restores its weights from +model_file+ and
  # performs greedy, character-by-character decoding in #translate.
  class Translator
    # @param model [String] model architecture; only 'transformer' is supported
    # @param vocabs_dir [String] directory containing input_vocab.json and target_vocab.json
    # @param hyperparameters [Hash] options forwarded to Secryst::Transformer.new;
    #   input_vocab_size/target_vocab_size are filled in from the loaded vocabularies
    # @param model_file [String] path of a saved state dict, read with Torch.load
    # @raise [ArgumentError] if +model+ is not 'transformer'
    def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
      # Validate before touching the file system so a bad model name fails fast.
      unless model == 'transformer'
        raise ArgumentError, 'Only transformer model is currently supported'
      end

      # NOTE(review): @device is not read anywhere in this class.
      @device = "cpu"
      @vocabs_dir = vocabs_dir
      load_vocabs

      @model = Secryst::Transformer.new(hyperparameters.merge({
        input_vocab_size: @input_vocab.length,
        target_vocab_size: @target_vocab.length,
      }))
      @model.load_state_dict(Torch.load(model_file))
      # Put the module into evaluation mode for inference.
      @model.eval
    end

    # Transliterates +phrase+ with greedy decoding, one target token per step.
    #
    # The phrase is split into characters and wrapped in <sos>/<eos>. Decoding
    # stops when the model predicts <eos> or after +max_seq_length+ steps.
    #
    # @param phrase [String] text to transliterate
    # @param max_seq_length [Integer] maximum number of decoding steps
    # @return [String] the decoded text (it is also printed to stdout)
    def translate(phrase, max_seq_length: 100)
      tokens = ['<sos>'] + phrase.chars + ['<eos>']
      # Build a 1 x n tensor of ids and transpose it into an n x 1 column.
      input = Torch.tensor([tokens.map { |ch| @input_vocab.stoi[ch] }]).t
      output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
      src_key_padding_mask = input.t.eq(@input_vocab.stoi.fetch('<pad>'))
      tgt_pad = @target_vocab.stoi.fetch('<pad>')

      max_seq_length.times do |step|
        tgt_key_padding_mask = output.t.eq(tgt_pad)
        # True strictly above the diagonal: a target position may not see later positions.
        tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
        opts = {
          tgt_mask: tgt_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        # One argmax id per target position; only the newest position is used.
        prediction = @model.call(input, output, opts).map { |scores| scores.argmax.item }
        next_token = prediction[step]
        break if @target_vocab.itos[next_token] == '<eos>'

        output = Torch.cat([output, Torch.tensor([[next_token]])])
      end

      # Drop the leading <sos> before mapping ids back to characters.
      result = output[1..-1].map { |tok| @target_vocab.itos[tok.item] }.join('')
      puts result
      result
    end

    private

    # Builds @input_vocab and @target_vocab from the JSON files in @vocabs_dir.
    def load_vocabs
      @input_vocab = read_vocab('input_vocab.json')
      @target_vocab = read_vocab('target_vocab.json')
    end

    # Parses one JSON file from @vocabs_dir and passes the result to Vocab.new
    # as its counter (presumably a token => frequency hash — confirm against the trainer).
    def read_vocab(file_name)
      Vocab.new(JSON.parse(File.read(File.join(@vocabs_dir, file_name))))
    end
  end
end
|
@@ -0,0 +1,88 @@
|
|
1
|
+
module Secryst
  # Bidirectional token <-> index mapping built from a token frequency table.
  # Appears to be a port of torchtext's legacy Vocab (see the Python-style
  # load_vectors comment below); pretrained vectors are not implemented.
  class Vocab
    UNK = "<unk>"
    attr_reader :stoi, :itos, :freqs

    # @param counter [Hash] token => frequency; kept as #freqs, never mutated
    # @param max_size [Integer, nil] maximum number of non-special tokens (nil = no limit)
    # @param min_freq [Integer] tokens with a lower frequency are dropped (values < 1 act as 1)
    # @param specials [Array] special tokens; copied, never mutated (may be frozen)
    # @param vectors [nil] pretrained vectors — any non-nil value raises "Not implemented yet"
    # @param unk_init [nil] must be nil while vectors are unsupported
    # @param vectors_cache [nil] must be nil while vectors are unsupported
    # @param specials_first [Boolean] put specials before (true) or after (false) regular tokens
    # @raise [RuntimeError] when vectors, unk_init or vectors_cache is supplied
    def initialize(
      counter, max_size: nil, min_freq: 1, specials: ["<unk>", "<pad>", "<sos>", "<eos>"],
      vectors: nil, unk_init: nil, vectors_cache: nil, specials_first: true
    )
      @freqs = counter
      counter = counter.dup
      min_freq = [min_freq, 1].max

      # dup: appending regular tokens below must not mutate the caller's specials array.
      @itos = specials_first ? specials.dup : []
      @unk_index = nil
      # only extend max size if specials are prepended
      max_size += specials.size if max_size && specials_first

      # frequencies of special tokens are not counted when building vocabulary
      specials.each { |tok| counter.delete(tok) }

      # sort by frequency (descending), then alphabetically; the list is sorted,
      # so the first token below min_freq ends the scan
      counter.sort_by { |tok, freq| [-freq, tok] }.each do |word, freq|
        break if freq < min_freq || @itos.length == max_size
        @itos << word
      end

      if specials.include?(UNK) # hard-coded for now
        unk_index = specials.index(UNK) # position in list
        # account for ordering of specials: when appended, UNK follows the regular tokens
        @unk_index = specials_first ? unk_index : @itos.length + unk_index
        # stoi[] of an unknown token falls back to the UNK index
        @stoi = Hash.new(@unk_index)
      else
        @stoi = {}
      end

      @itos.concat(specials) unless specials_first

      # stoi is simply a reverse dict for itos
      @itos.each_with_index { |tok, idx| @stoi[tok] = idx }

      @vectors = nil
      if vectors.nil?
        raise "Failed assertion" unless unk_init.nil?
        raise "Failed assertion" unless vectors_cache.nil?
      else
        # self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        raise "Not implemented yet"
      end
    end

    # Index of +token+, falling back to the UNK index for unknown tokens.
    #
    # @raise [KeyError] if +token+ is unknown and the vocabulary has no UNK token
    def [](token)
      # Block form: UNK is only looked up when +token+ is missing, so a
      # vocabulary without UNK can still index the tokens it contains.
      @stoi.fetch(token) { @stoi.fetch(UNK) }
    end

    # Number of entries (specials included).
    def length
      @itos.length
    end
    alias_method :size, :length

    # Counts every token yielded by +iterator+ (each element is itself an
    # enumerable of tokens) and builds a Vocab from the counts. Prints a
    # progress line every 10000 sequences.
    #
    # @param options [Hash] keyword options forwarded to Vocab.new (max_size:, min_freq:, ...)
    def self.build_vocab_from_iterator(iterator, **options)
      counter = Hash.new(0)
      processed = 0
      iterator.each do |tokens|
        tokens.each { |token| counter[token] += 1 }
        processed += 1
        puts "Processed #{processed}" if processed % 10000 == 0
      end
      Vocab.new(counter, **options)
    end
  end
end
|
metadata
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: secryst
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- project_contributors
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2020-10-05 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: torch-rb
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0.4'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0.4'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rake
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rspec
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
description: Seq2seq transformer for transliteration in Ruby.
|
56
|
+
email:
|
57
|
+
executables: []
|
58
|
+
extensions: []
|
59
|
+
extra_rdoc_files: []
|
60
|
+
files:
|
61
|
+
- README.adoc
|
62
|
+
- lib/secryst-trainer.rb
|
63
|
+
- lib/secryst.rb
|
64
|
+
- lib/secryst/clip_grad_norm.rb
|
65
|
+
- lib/secryst/multi_head_attention_forward.rb
|
66
|
+
- lib/secryst/multihead_attention.rb
|
67
|
+
- lib/secryst/trainer.rb
|
68
|
+
- lib/secryst/transformer.rb
|
69
|
+
- lib/secryst/translator.rb
|
70
|
+
- lib/secryst/version.rb
|
71
|
+
- lib/secryst/vocab.rb
|
72
|
+
homepage: https://github.com/secryst/secryst
|
73
|
+
licenses:
|
74
|
+
- BSD-2-Clause
|
75
|
+
metadata: {}
|
76
|
+
post_install_message:
|
77
|
+
rdoc_options: []
|
78
|
+
require_paths:
|
79
|
+
- lib
|
80
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
81
|
+
requirements:
|
82
|
+
- - ">="
|
83
|
+
- !ruby/object:Gem::Version
|
84
|
+
version: '2.7'
|
85
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
requirements: []
|
91
|
+
rubygems_version: 3.1.2
|
92
|
+
signing_key:
|
93
|
+
specification_version: 4
|
94
|
+
summary: Seq2seq transformer for transliteration in Ruby.
|
95
|
+
test_files: []
|