secryst 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,51 @@
1
module Secryst
  # Greedy character-level decoder around a trained transformer model.
  #
  # The vocabularies are read from +input_vocab.json+ and +target_vocab.json+
  # inside +vocabs_dir+, and the model weights from +model_file+.
  class Translator
    # @param model [String] model type; only 'transformer' is accepted
    # @param vocabs_dir [String] directory holding input_vocab.json and
    #   target_vocab.json
    # @param hyperparameters [Hash] options forwarded to Secryst::Transformer.new;
    #   the two vocabulary sizes are merged in on top of them
    # @param model_file [String] path handed to Torch.load to obtain the state dict
    # @raise [ArgumentError] when +model+ is anything other than 'transformer'
    def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
      @device = "cpu"
      @vocabs_dir = vocabs_dir

      # Vocabulary sizes are needed to construct the model, so load them first.
      load_vocabs

      unless model == 'transformer'
        raise ArgumentError, 'Only transformer model is currently supported'
      end

      @model = Secryst::Transformer.new(hyperparameters.merge({
        input_vocab_size: @input_vocab.length,
        target_vocab_size: @target_vocab.length,
      }))

      @model.load_state_dict(Torch.load(model_file))
      @model.eval
    end

    # Translates +phrase+ one character at a time, always appending the
    # highest-scoring next token, until '<eos>' is predicted or
    # +max_seq_length+ steps have run.
    #
    # The result is printed to stdout (as before) and is now also returned,
    # so callers can use the translation programmatically.
    #
    # @param phrase [String] text to translate; it is split into characters
    # @param max_seq_length [Integer] upper bound on the number of decoding steps
    # @return [String] the decoded text, without the leading '<sos>' token
    def translate(phrase, max_seq_length: 100)
      tokens = ['<sos>'] + phrase.chars + ['<eos>']
      input = Torch.tensor([tokens.map { |tok| @input_vocab.stoi[tok] }]).t
      output = Torch.tensor([[@target_vocab.stoi['<sos>']]])

      # Look the padding index up instead of hard-coding it; with the default
      # specials used by Vocab this is the same value (1) the literal had.
      src_key_padding_mask = input.t.eq(@input_vocab.stoi['<pad>'])
      target_pad_index = @target_vocab.stoi['<pad>']

      max_seq_length.times do |step|
        tgt_key_padding_mask = output.t.eq(target_pad_index)
        # Boolean mask that is true strictly above the diagonal
        # (presumably so a position cannot attend to later ones — confirm
        # against the Transformer implementation).
        tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
        opts = {
          tgt_mask: tgt_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        # NOTE(review): opts is passed as a trailing positional hash. If the
        # model's call/forward declares keyword parameters, this relies on
        # Ruby 2.7's implicit hash-to-keyword conversion and would need
        # **opts on Ruby 3 — confirm against the model's signature.
        prediction = @model.call(input, output, opts).map { |scores| scores.argmax.item }
        break if @target_vocab.itos[prediction[step]] == '<eos>'
        output = Torch.cat([output, Torch.tensor([[prediction[step]]])])
      end

      # Skip the first row, which holds the '<sos>' seed token.
      translation = output[1..-1].map { |idx| @target_vocab.itos[idx.item] }.join('')
      puts translation
      translation
    end

    private

    # Reads both vocabulary JSON files from @vocabs_dir and wraps each parsed
    # object in a Vocab (default specials, specials first).
    def load_vocabs
      @input_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/input_vocab.json")))
      @target_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/target_vocab.json")))
    end
  end
end
@@ -0,0 +1,3 @@
1
module Secryst
  # Release version of the gem. Frozen so the shared constant string cannot
  # be mutated in place by callers.
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1,88 @@
1
module Secryst
  # Bidirectional token <-> index mapping built from a token-frequency hash.
  #
  # +itos+ is the ordered token list, +stoi+ the reverse lookup and +freqs+
  # the counter the vocabulary was built from.
  class Vocab
    UNK = "<unk>"
    attr_reader :stoi, :itos, :freqs

    # @param counter [Hash{String=>Integer}] token => frequency
    # @param max_size [Integer, nil] maximum number of non-special tokens when
    #   specials come first (the limit is extended by specials.size); nil for
    #   no limit
    # @param min_freq [Integer] tokens seen fewer times are dropped; values
    #   below 1 are treated as 1
    # @param specials [Array<String>] special tokens; never mutated
    # @param vectors [Object, nil] not supported yet; must be nil
    # @param unk_init [Object, nil] must be nil while vectors are unsupported
    # @param vectors_cache [Object, nil] must be nil while vectors are unsupported
    # @param specials_first [Boolean] place specials before (true) or after
    #   (false) the regular tokens
    # @raise [RuntimeError] when vectors, unk_init or vectors_cache is given
    def initialize(
      counter, max_size: nil, min_freq: 1, specials: ["<unk>", "<pad>", "<sos>", "<eos>"],
      vectors: nil, unk_init: nil, vectors_cache: nil, specials_first: true
    )
      @freqs = counter
      counter = counter.dup
      min_freq = [min_freq, 1].max

      @itos = []
      @unk_index = nil

      if specials_first
        # Copy the list: @itos is appended to below, and aliasing the argument
        # would silently grow the caller's specials array.
        @itos = specials.dup
        # only extend max size if specials are prepended
        max_size += specials.size if max_size
      end

      # frequencies of special tokens are not counted when building vocabulary
      # in frequency order
      specials.each { |tok| counter.delete(tok) }

      # sort by frequency (descending), then alphabetically
      words_and_frequencies = counter.sort_by { |word, freq| [-freq, word] }

      words_and_frequencies.each do |word, freq|
        break if freq < min_freq || @itos.length == max_size
        @itos << word
      end

      if specials.include?(UNK) # hard-coded for now
        unk_index = specials.index(UNK) # position in list
        # account for ordering of specials, set variable
        @unk_index = specials_first ? unk_index : @itos.length + unk_index
        # unknown tokens looked up through stoi resolve to the unk index
        @stoi = Hash.new(@unk_index)
      else
        @stoi = {}
      end

      @itos.concat(specials) unless specials_first

      # stoi is simply a reverse dict for itos
      @itos.each_with_index do |tok, i|
        @stoi[tok] = i
      end

      @vectors = nil
      if !vectors.nil?
        # self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        raise "Not implemented yet"
      else
        raise "Failed assertion" unless unk_init.nil?
        raise "Failed assertion" unless vectors_cache.nil?
      end
    end

    # Index of +token+, falling back to the index of the unknown token.
    #
    # The fallback is evaluated lazily: a vocabulary built without "<unk>"
    # among its specials still resolves known tokens, and yields nil for
    # unknown ones, instead of raising KeyError on every lookup.
    #
    # @param token [String]
    # @return [Integer, nil]
    def [](token)
      @stoi.fetch(token) { @stoi[UNK] }
    end

    # @return [Integer] number of tokens, specials included
    def length
      @itos.length
    end
    alias_method :size, :length

    # Counts every token yielded by +iterator+ and builds a vocabulary with
    # the default options. Prints a progress line every 10000 items.
    #
    # @param iterator [#each] yields collections of tokens
    # @return [Vocab]
    def self.build_vocab_from_iterator(iterator)
      counter = Hash.new(0)
      processed = 0
      iterator.each do |tokens|
        tokens.each { |token| counter[token] += 1 }
        processed += 1
        puts "Processed #{processed}" if processed % 10000 == 0
      end
      new(counter)
    end
  end
end
metadata ADDED
@@ -0,0 +1,95 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: secryst
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - project_contibutors
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2020-10-05 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: torch-rb
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '0.4'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.4'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rake
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rspec
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ description: Seq2seq transformer for transliteration in Ruby.
56
+ email:
57
+ executables: []
58
+ extensions: []
59
+ extra_rdoc_files: []
60
+ files:
61
+ - README.adoc
62
+ - lib/secryst-trainer.rb
63
+ - lib/secryst.rb
64
+ - lib/secryst/clip_grad_norm.rb
65
+ - lib/secryst/multi_head_attention_forward.rb
66
+ - lib/secryst/multihead_attention.rb
67
+ - lib/secryst/trainer.rb
68
+ - lib/secryst/transformer.rb
69
+ - lib/secryst/translator.rb
70
+ - lib/secryst/version.rb
71
+ - lib/secryst/vocab.rb
72
+ homepage: https://github.com/secryst/secryst
73
+ licenses:
74
+ - BSD-2-Clause
75
+ metadata: {}
76
+ post_install_message:
77
+ rdoc_options: []
78
+ require_paths:
79
+ - lib
80
+ required_ruby_version: !ruby/object:Gem::Requirement
81
+ requirements:
82
+ - - ">="
83
+ - !ruby/object:Gem::Version
84
+ version: '2.7'
85
+ required_rubygems_version: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ requirements: []
91
+ rubygems_version: 3.1.2
92
+ signing_key:
93
+ specification_version: 4
94
+ summary: Seq2seq transformer for transliteration in Ruby.
95
+ test_files: []