secryst-trainer 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,51 @@
1
module Secryst
  # Loads a trained model together with its input/target vocabularies and
  # translates a phrase character by character, picking the arg-max token at
  # every decoding step.
  class Translator
    # Index compared against when building the key-padding masks. It equals
    # the position of "<pad>" in Vocab's default specials list
    # (["<unk>", "<pad>", "<sos>", "<eos>"]), which is what load_vocabs uses.
    PAD_INDEX = 1

    # @param model [String] model kind; only 'transformer' is accepted
    # @param vocabs_dir [String] directory holding input_vocab.json and
    #   target_vocab.json
    # @param hyperparameters [Hash] options forwarded to
    #   Secryst::Transformer.new; the two vocabulary sizes are merged in here
    # @param model_file [String] path handed to Torch.load to obtain the
    #   state dict
    # @raise [ArgumentError] when +model+ is anything other than 'transformer'
    def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
      @device = "cpu"
      @vocabs_dir = vocabs_dir

      # Vocabularies are read first because their sizes parameterize the model.
      load_vocabs

      unless model == 'transformer'
        raise ArgumentError, 'Only transformer model is currently supported'
      end

      @model = Secryst::Transformer.new(hyperparameters.merge({
        input_vocab_size: @input_vocab.length,
        target_vocab_size: @target_vocab.length,
      }))

      @model.load_state_dict(Torch.load(model_file))
      @model.eval
    end

    # Translates +phrase+ by repeatedly feeding the tokens decoded so far back
    # into the model until '<eos>' is predicted or +max_seq_length+ steps have
    # been taken.
    #
    # The result is printed to stdout (as before) and is now also returned so
    # that callers can use it programmatically.
    #
    # @param phrase [String] text to translate; it is split into characters
    # @param max_seq_length [Integer] upper bound on decoding steps
    # @return [String] the decoded tokens joined together, without '<sos>'
    def translate(phrase, max_seq_length: 100)
      source_tokens = ['<sos>'] + phrase.chars + ['<eos>']
      # A single row of ids is built and then transposed, so the source
      # tensor has one column.
      input = Torch.tensor([source_tokens.map { |token| @input_vocab.stoi[token] }]).t
      output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
      src_key_padding_mask = input.t.eq(PAD_INDEX)

      max_seq_length.times do |step|
        tgt_key_padding_mask = output.t.eq(PAD_INDEX)
        # Square mask over the step + 1 tokens decoded so far.
        # NOTE(review): presumably this is the usual "no peeking ahead" mask
        # expected by the transformer's tgt_mask option - confirm against
        # Secryst::Transformer.
        tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
        opts = {
          tgt_mask: tgt_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        # One arg-max index per element yielded by iterating the model output.
        prediction = @model.call(input, output, opts).map { |scores| scores.argmax.item }
        # Only the prediction at the current step is consumed.
        break if @target_vocab.itos[prediction[step]] == '<eos>'
        output = Torch.cat([output, Torch.tensor([[prediction[step]]])])
      end

      # Drop the leading '<sos>' and map the remaining ids back to tokens.
      translation = output[1..-1].map { |token_id| @target_vocab.itos[token_id.item] }.join('')
      puts translation
      translation
    end

    private

    # Reads both vocabulary JSON files from @vocabs_dir and wraps the parsed
    # data in Vocab instances (default Vocab options).
    def load_vocabs
      @input_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/input_vocab.json")))
      @target_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/target_vocab.json")))
    end
  end
end
@@ -0,0 +1,3 @@
1
module Secryst
  # Version string of this release of the package.
  VERSION = '0.1.0'
end
@@ -0,0 +1,88 @@
1
module Secryst
  # Two-way mapping between tokens and integer indices, built from a
  # token => frequency hash.
  #
  # * +itos+  - Array of tokens; the array position is the token's index.
  # * +stoi+  - Hash from token to index. When "<unk>" is one of the specials,
  #   looking up an unknown token through +stoi+ yields the "<unk>" index.
  # * +freqs+ - the counter object that was passed in (not copied).
  class Vocab
    UNK = "<unk>"
    attr_reader :stoi, :itos, :freqs

    # @param counter [Hash{String => Integer}] token frequencies
    # @param max_size [Integer, nil] maximum number of non-special tokens to
    #   keep when specials come first; nil means no limit
    # @param min_freq [Integer] tokens seen fewer times than this are dropped
    #   (values below 1 are treated as 1)
    # @param specials [Array<String>] special tokens; they are excluded from
    #   the frequency ordering. The array is not modified.
    # @param vectors [Object, nil] must be nil; vectors are not implemented
    # @param unk_init [Object, nil] must be nil when +vectors+ is nil
    # @param vectors_cache [Object, nil] must be nil when +vectors+ is nil
    # @param specials_first [Boolean] put the specials at the start of +itos+
    #   (true) or at the end (false)
    # @raise [RuntimeError] when +vectors+, +unk_init+ or +vectors_cache+ is
    #   given
    def initialize(
      counter, max_size: nil, min_freq: 1, specials: ["<unk>", "<pad>", "<sos>", "<eos>"],
      vectors: nil, unk_init: nil, vectors_cache: nil, specials_first: true
    )
      @freqs = counter
      # Work on a copy so that removing the specials below does not touch the
      # caller's counter.
      counter = counter.dup
      min_freq = [min_freq, 1].max

      @itos = []
      @unk_index = nil

      if specials_first
        # Copy: @itos is appended to below, and the caller's array must not
        # grow as a side effect.
        @itos = specials.dup
        # only extend max size if specials are prepended
        max_size += specials.size if max_size
      end

      # Frequencies of special tokens are not counted when building the
      # vocabulary in frequency order.
      specials.each { |tok| counter.delete(tok) }

      # Most frequent first; ties are broken alphabetically.
      counter.sort_by { |word, freq| [-freq, word] }.each do |word, freq|
        break if freq < min_freq || @itos.length == max_size
        @itos << word
      end

      unk_position = specials.index(UNK)
      if unk_position
        # When specials are appended, their indices start after the words
        # collected so far.
        @unk_index = specials_first ? unk_position : @itos.length + unk_position
        @stoi = Hash.new(@unk_index)
      else
        @stoi = {}
      end

      @itos.concat(specials) unless specials_first

      # stoi is simply the reverse mapping of itos
      @itos.each_with_index { |tok, index| @stoi[tok] = index }

      @vectors = nil
      raise "Not implemented yet" unless vectors.nil?
      raise "Failed assertion" unless unk_init.nil?
      raise "Failed assertion" unless vectors_cache.nil?
    end

    # Looks up the index of +token+, falling back to the index of "<unk>".
    #
    # The fallback is evaluated lazily, so a vocabulary built without "<unk>"
    # can still resolve the tokens it does contain.
    #
    # @param token [String]
    # @return [Integer]
    # @raise [KeyError] when +token+ is unknown and the vocabulary has no
    #   "<unk>" entry
    def [](token)
      @stoi.fetch(token) { @stoi.fetch(UNK) }
    end

    # @return [Integer] number of entries in the vocabulary, specials included
    def length
      @itos.length
    end
    alias_method :size, :length

    # Counts every token yielded by +iterator+ and builds a Vocab with the
    # default options. Prints a progress line every 10000 token lists.
    #
    # @param iterator [#each] yields collections of tokens
    # @return [Vocab]
    def self.build_vocab_from_iterator(iterator)
      counter = Hash.new(0)
      processed = 0
      iterator.each do |tokens|
        tokens.each { |token| counter[token] += 1 }
        processed += 1
        puts "Processed #{processed}" if processed % 10000 == 0
      end
      Vocab.new(counter)
    end
  end
end
metadata ADDED
@@ -0,0 +1,138 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: secryst-trainer
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - project_contibutors
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2020-10-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: torch-rb
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '0.4'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.4'
27
+ - !ruby/object:Gem::Dependency
28
+ name: numo
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '0.1'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '0.1'
41
+ - !ruby/object:Gem::Dependency
42
+ name: numo-linalg
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '0.1'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '0.1'
55
+ - !ruby/object:Gem::Dependency
56
+ name: secryst
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - '='
60
+ - !ruby/object:Gem::Version
61
+ version: 0.1.0
62
+ type: :runtime
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - '='
67
+ - !ruby/object:Gem::Version
68
+ version: 0.1.0
69
+ - !ruby/object:Gem::Dependency
70
+ name: rake
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - ">="
74
+ - !ruby/object:Gem::Version
75
+ version: '0'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - ">="
81
+ - !ruby/object:Gem::Version
82
+ version: '0'
83
+ - !ruby/object:Gem::Dependency
84
+ name: rspec
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '0'
97
+ description:
98
+ email:
99
+ executables: []
100
+ extensions: []
101
+ extra_rdoc_files: []
102
+ files:
103
+ - README.adoc
104
+ - lib/secryst-trainer.rb
105
+ - lib/secryst.rb
106
+ - lib/secryst/clip_grad_norm.rb
107
+ - lib/secryst/multi_head_attention_forward.rb
108
+ - lib/secryst/multihead_attention.rb
109
+ - lib/secryst/trainer.rb
110
+ - lib/secryst/transformer.rb
111
+ - lib/secryst/translator.rb
112
+ - lib/secryst/version.rb
113
+ - lib/secryst/vocab.rb
114
+ homepage: https://github.com/secryst/secryst
115
+ licenses:
116
+ - BSD-2-Clause
117
+ metadata: {}
118
+ post_install_message:
119
+ rdoc_options: []
120
+ require_paths:
121
+ - lib
122
+ required_ruby_version: !ruby/object:Gem::Requirement
123
+ requirements:
124
+ - - ">="
125
+ - !ruby/object:Gem::Version
126
+ version: '2.7'
127
+ required_rubygems_version: !ruby/object:Gem::Requirement
128
+ requirements:
129
+ - - ">="
130
+ - !ruby/object:Gem::Version
131
+ version: '0'
132
+ requirements: []
133
+ rubygems_version: 3.1.2
134
+ signing_key:
135
+ specification_version: 4
136
+ summary: A seq2seq transformer suited for transliteration. Written in Ruby. Includes
137
+ packages for training models
138
+ test_files: []