secryst 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,51 @@
1
module Secryst
  # Character-level greedy decoder around a trained Secryst model.
  #
  # The vocabularies and the saved model weights are loaded once in the
  # constructor; #translate then decodes one phrase per call.
  class Translator
    # Token index compared against when building the key-padding masks.
    # NOTE(review): 1 is the position of "<pad>" in Vocab's default specials
    # list; confirm the stored vocabularies were built with those defaults.
    PAD_INDEX = 1

    # @param model [String] architecture name; only 'transformer' is accepted
    # @param vocabs_dir [String] directory holding input_vocab.json and
    #   target_vocab.json
    # @param hyperparameters [Hash] options handed to Secryst::Transformer.new;
    #   :input_vocab_size and :target_vocab_size are merged in from the vocabs
    # @param model_file [String] argument given to Torch.load to obtain the
    #   state dict
    # @raise [ArgumentError] if +model+ is anything other than 'transformer'
    def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
      # Fail fast on an unsupported architecture, before any file is read.
      raise ArgumentError, 'Only transformer model is currently supported' unless model == 'transformer'

      @device = "cpu" # stored but not read anywhere in this class
      @vocabs_dir = vocabs_dir

      load_vocabs

      # Explicit ** so the options arrive as keywords on Ruby 3 as well; a
      # receiver that takes a trailing options hash still gets a hash.
      @model = Secryst::Transformer.new(
        **hyperparameters.merge(
          input_vocab_size: @input_vocab.length,
          target_vocab_size: @target_vocab.length
        )
      )

      @model.load_state_dict(Torch.load(model_file))
      @model.eval
    end

    # Decodes +phrase+ one character at a time and prints the result to
    # standard output.
    #
    # Each step feeds the tokens produced so far back into the model, takes
    # the argmax prediction for the newest position and appends it, stopping
    # at '<eos>' or after +max_seq_length+ steps.
    #
    # @param phrase [String] text to translate; split into characters
    # @param max_seq_length [Integer] upper bound on the number of decoding steps
    # @return [nil] the decoded text is printed, not returned
    def translate(phrase, max_seq_length: 100)
      tokens = ['<sos>'] + phrase.chars + ['<eos>']
      input = Torch.tensor([tokens.map { |token| @input_vocab.stoi[token] }]).t
      output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
      src_key_padding_mask = input.t.eq(PAD_INDEX)

      max_seq_length.times do |step|
        tgt_key_padding_mask = output.t.eq(PAD_INDEX)
        # True strictly above the diagonal after the transpose; presumably
        # the positions the decoder must not attend to (confirm against
        # Secryst::Transformer).
        tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
        opts = {
          tgt_mask: tgt_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        # Masks are passed as real keyword arguments (see the note in
        # #initialize); the block variable no longer shadows the loop index.
        prediction = @model.call(input, output, **opts).map { |scores| scores.argmax.item }
        break if @target_vocab.itos[prediction[step]] == '<eos>'

        output = Torch.cat([output, Torch.tensor([[prediction[step]]])])
      end

      # Skip the leading '<sos>' entry when rendering the result.
      puts output[1..-1].map { |index| @target_vocab.itos[index.item] }.join('')
    end

    private

    # Populates @input_vocab and @target_vocab from the JSON files found in
    # @vocabs_dir.
    def load_vocabs
      @input_vocab = read_vocab('input_vocab.json')
      @target_vocab = read_vocab('target_vocab.json')
    end

    # Parses one JSON file from @vocabs_dir and wraps it in a Vocab.
    def read_vocab(file_name)
      Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/#{file_name}")))
    end
  end
end
@@ -0,0 +1,3 @@
1
module Secryst
  # Release version of the secryst gem. Frozen so the constant's string
  # cannot be mutated in place at runtime.
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1,88 @@
1
module Secryst
  # Two-way mapping between tokens and integer indices, built from a table
  # of token frequencies.
  #
  # +itos+ is the index -> token array, +stoi+ the token -> index hash and
  # +freqs+ the frequency table exactly as it was passed in.
  class Vocab
    UNK = "<unk>"
    attr_reader :stoi, :itos, :freqs

    # @param counter [Hash] token => frequency; it is not modified
    # @param max_size [Integer, nil] maximum number of non-special tokens when
    #   specials come first; nil means unlimited
    # @param min_freq [Integer] tokens seen fewer times are dropped (values
    #   below 1 are treated as 1)
    # @param specials [Array<String>] special tokens; excluded from the
    #   frequency ordering
    # @param vectors [Object, nil] must be nil; vector loading is not implemented
    # @param unk_init [Object, nil] must be nil while +vectors+ is nil
    # @param vectors_cache [Object, nil] must be nil while +vectors+ is nil
    # @param specials_first [Boolean] put the specials before (true) or after
    #   (false) the regular tokens
    # @raise [RuntimeError] if +vectors+, +unk_init+ or +vectors_cache+ is given
    def initialize(
      counter, max_size: nil, min_freq: 1, specials: ["<unk>", "<pad>", "<sos>", "<eos>"],
      vectors: nil, unk_init: nil, vectors_cache: nil, specials_first: true
    )
      @freqs = counter
      counter = counter.dup
      min_freq = [min_freq, 1].max

      @itos = []
      @unk_index = nil

      if specials_first
        # Copy the list: words are appended to @itos below, and that must not
        # leak into an array owned by the caller.
        @itos = specials.dup
        # only extend max size if specials are prepended
        max_size += specials.size if max_size
      end

      # frequencies of special tokens are not counted when building vocabulary
      # in frequency order
      specials.each { |tok| counter.delete(tok) }

      # sort by frequency, then alphabetically
      words_and_frequencies = counter.sort_by { |word, freq| [-freq, word] }

      words_and_frequencies.each do |word, freq|
        # Sorted by descending frequency, so the first rare word ends the scan.
        break if freq < min_freq || @itos.length == max_size

        @itos << word
      end

      if specials.include?(UNK) # hard-coded for now
        unk_index = specials.index(UNK) # position in list
        # account for ordering of specials, set variable
        @unk_index = specials_first ? unk_index : @itos.length + unk_index
        # Unknown tokens resolve to the <unk> index through the hash default.
        @stoi = Hash.new(@unk_index)
      else
        @stoi = {}
      end

      @itos.concat(specials) unless specials_first

      # stoi is simply a reverse dict for itos
      @itos.each_with_index do |tok, i|
        @stoi[tok] = i
      end

      @vectors = nil
      if !vectors.nil?
        # self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        raise "Not implemented yet"
      else
        raise "Failed assertion" unless unk_init.nil?
        raise "Failed assertion" unless vectors_cache.nil?
      end
    end

    # Index of +token+, falling back to the index of "<unk>".
    #
    # The fallback is resolved lazily, so a vocabulary built without "<unk>"
    # still answers lookups for tokens it knows.
    #
    # @param token [String]
    # @return [Integer]
    # @raise [KeyError] if +token+ is unknown and the vocabulary has no "<unk>"
    def [](token)
      @stoi.fetch(token) { @stoi.fetch(UNK) }
    end

    # @return [Integer] number of entries, specials included
    def length
      @itos.length
    end
    alias_method :size, :length

    # Counts every token yielded by +iterator+ and builds a Vocab with the
    # default options. Prints a progress line every 10000 items.
    #
    # @param iterator [#each] yields collections of tokens
    # @return [Vocab]
    def self.build_vocab_from_iterator(iterator)
      counter = Hash.new(0)
      processed = 0
      iterator.each do |tokens|
        tokens.each { |token| counter[token] += 1 }
        processed += 1
        puts "Processed #{processed}" if processed % 10000 == 0
      end
      Vocab.new(counter)
    end
  end
end
metadata ADDED
@@ -0,0 +1,95 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: secryst
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - project_contibutors
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2020-10-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: torch-rb
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '0.4'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.4'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rake
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rspec
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ description: Seq2seq transformer for transliteration in Ruby.
56
+ email:
57
+ executables: []
58
+ extensions: []
59
+ extra_rdoc_files: []
60
+ files:
61
+ - README.adoc
62
+ - lib/secryst-trainer.rb
63
+ - lib/secryst.rb
64
+ - lib/secryst/clip_grad_norm.rb
65
+ - lib/secryst/multi_head_attention_forward.rb
66
+ - lib/secryst/multihead_attention.rb
67
+ - lib/secryst/trainer.rb
68
+ - lib/secryst/transformer.rb
69
+ - lib/secryst/translator.rb
70
+ - lib/secryst/version.rb
71
+ - lib/secryst/vocab.rb
72
+ homepage: https://github.com/secryst/secryst
73
+ licenses:
74
+ - BSD-2-Clause
75
+ metadata: {}
76
+ post_install_message:
77
+ rdoc_options: []
78
+ require_paths:
79
+ - lib
80
+ required_ruby_version: !ruby/object:Gem::Requirement
81
+ requirements:
82
+ - - ">="
83
+ - !ruby/object:Gem::Version
84
+ version: '2.7'
85
+ required_rubygems_version: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ requirements: []
91
+ rubygems_version: 3.1.2
92
+ signing_key:
93
+ specification_version: 4
94
+ summary: Seq2seq transformer for transliteration in Ruby.
95
+ test_files: []