basic-lstm 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (6) hide show
  1. checksums.yaml +7 -0
  2. data/lib/CELL.rb +261 -0
  3. data/lib/DICTIONARY.rb +123 -0
  4. data/lib/ENCODER.rb +132 -0
  5. data/lib/NETWORK.rb +163 -0
  6. metadata +46 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 77cfe659ce2d038f9f415fee83f260f7b7923933fa9275536f490066f26c96af
4
+ data.tar.gz: d5f35dd0a7625c278f9c17230e5fa7b31f8495405d22302786ef7f831877c692
5
+ SHA512:
6
+ metadata.gz: 2c854837f63551d21cde6a45d57855cda75a1b8b8e4db14d51ce540bbc24e127995793046385c56e9e05a9adaf12b7b2fce1fc05994d90007f39409c9056dd65
7
+ data.tar.gz: 57d4f463984171afd4a90d267d430b5226c649edcc08104634c4ab803d46cf8fd5c2428653c888ad3ab6d8724eba646bfec4c118c3bea2e0f64a240ff673db4f
data/lib/CELL.rb ADDED
@@ -0,0 +1,261 @@
1
+ #!/usr/bin/env ruby
2
+ =begin
3
+ ██╗░░░░░░██████╗████████╗███╗░░░███╗
4
+ ██║░░░░░██╔════╝╚══██╔══╝████╗░████║
5
+ ██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
6
+ ██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
7
+ ███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
8
+ ╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
9
+ Created: 14/09/2023.
10
+ Version: 1.0.0
11
+ Author: Ryan May
12
+ This is a text generator LSTM developed in Ruby.
13
+ Through this project I aim to better understand how LSTM networks are built and trained.
14
+ This LSTM network uses the Numo::NArray library to compute the cell states and to run the cell network.
15
+ Expansion on this project will be utilised in my Honours project.
16
+ =end
17
+ require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
18
+ require 'numo/narray'
19
+ include Numo
20
+ =begin
21
+ ██╗░░░░░░██████╗████████╗███╗░░░███╗ ░█████╗░███████╗██╗░░░░░██╗░░░░░
22
+ ██║░░░░░██╔════╝╚══██╔══╝████╗░████║ ██╔══██╗██╔════╝██║░░░░░██║░░░░░
23
+ ██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║ ██║░░╚═╝█████╗░░██║░░░░░██║░░░░░
24
+ ██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║ ██║░░██╗██╔══╝░░██║░░░░░██║░░░░░
25
+ ███████╗██████╔╝░░░██║░░░██║░╚═╝░██║ ╚█████╔╝███████╗███████╗███████╗
26
+ ╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝ ░╚════╝░╚══════╝╚══════╝╚══════╝
27
+ =end
28
class LSTM_CELL
  # Constructor
  # Configures the cell by allocating its weights, biases, state and
  # gradient accumulators as Numo::DFloat arrays of shape (1, sz).
  #
  # @param alpha [Numeric] learning rate used by #applyWeightChange.
  # @param sz [Integer] width of the input, state and output vectors.
  # @param terminal_output [Object] unused; accepted for call compatibility.
  def init(alpha, sz, terminal_output = nil)
    @sz = sz
    # Stored per instance. The former class variable (@@Alpha) was shared by
    # every LSTM_CELL ever created, so initialising a second network silently
    # changed the learning rate of all cells of the first one.
    @alpha = alpha
    ### Weight matrices with regard to inputs
    @Wg = DFloat.new(1, @sz).rand
    @Wi = DFloat.new(1, @sz).rand
    @Wf = DFloat.new(1, @sz).rand
    @Wo = DFloat.new(1, @sz).rand
    ### Weight matrices with regard to the previous cell output
    @Ug = DFloat.new(1, @sz).rand
    @Ui = DFloat.new(1, @sz).rand
    @Uf = DFloat.new(1, @sz).rand
    @Uo = DFloat.new(1, @sz).rand
    ### Bias
    @Bg = DFloat.new(1, @sz).rand
    @Bi = DFloat.new(1, @sz).rand
    @Bf = DFloat.new(1, @sz).rand
    @Bo = DFloat.new(1, @sz).rand
    ### Cell input and target
    @Xt = DFloat.zeros(1, @sz)
    @Yt = DFloat.zeros(1, @sz)
    ### Cell state
    @Hprev = DFloat.zeros(1, @sz) # Previous cell output
    @Cprev = DFloat.zeros(1, @sz) # Previous cell state
    @Ct = DFloat.zeros(1, @sz)    # Cell state
    @Ht = DFloat.zeros(1, @sz)    # Cell output
    ### Internal cell gates (recomputed by #forwardPropagation)
    @F = DFloat.new(1, @sz).rand
    @I = DFloat.new(1, @sz).rand
    @O = DFloat.new(1, @sz).rand
    @G = DFloat.new(1, @sz).rand
    ### BPTT gradient accumulators
    resetGradients
  end

  # Hadamard (element-wise) product, always returned as a DFloat with
  # nArr1's shape. Only the leading nArr1.shape window of nArr2 is used:
  # #backwardPropagation can hand a (sz, sz) bottom_diff_h to the next cell,
  # and only its first row takes part in the product.
  def hp(nArr1, nArr2)
    rows, cols = nArr1.shape
    DFloat.cast(nArr1) * DFloat.cast(nArr2[0...rows, 0...cols])
  end

  # Forward propagation for one time step:
  #   f, i, o = sigmoid(W.Xt + U.Hprev + B)
  #   g       = tanh(Wg.Xt + Ug.Hprev + Bg)
  #   Ct      = f * Cprev + i * g
  #   Ht      = o * tanh(Ct)
  def forwardPropagation()
    ### Calculating gate values
    @Zf = @Wf.dot(@Xt) + @Uf.dot(@Hprev) + @Bf
    @F = sigmoidVector(@Zf)
    @Zi = @Wi.dot(@Xt) + @Ui.dot(@Hprev) + @Bi
    @I = sigmoidVector(@Zi)
    @Zo = @Wo.dot(@Xt) + @Uo.dot(@Hprev) + @Bo
    @O = sigmoidVector(@Zo)
    @Zg = @Wg.dot(@Xt) + @Ug.dot(@Hprev) + @Bg
    @G = NMath.tanh(@Zg)
    ### Calculating cell states
    @Ct = hp(@F, @Cprev) + hp(@I, @G)
    @Ht = hp(@O, NMath.tanh(@Ct))
  end

  # Back propagation for one time step.
  #
  # The caller (LSTM_NETWORK) supplies deltas of the form 2 * (Yt - Ht), i.e.
  # the NEGATIVE loss gradient. Every accumulator therefore adds its delta,
  # and #applyWeightChange adds alpha * delta, which is gradient descent for
  # all of W, U and B alike.
  #
  # @param top_diff_h delta arriving at this cell's output Ht.
  # @param top_diff_c delta arriving at this cell's state Ct from the next cell.
  # Leaves the deltas for Hprev / Cprev in getBottomDeltaHt / getBottomDeltaCt.
  def backwardPropagation(top_diff_h, top_diff_c)
    # Ht = O * tanh(Ct), so both the output-gate delta and the state delta
    # must go through tanh(Ct) (the latter also through its derivative).
    tanh_ct = NMath.tanh(@Ct)
    dC = hp(hp(@O, tanhDer(tanh_ct)), top_diff_h) + top_diff_c
    dO = hp(tanh_ct, top_diff_h)
    dI = hp(@G, dC)
    dG = hp(@I, dC)
    dF = hp(@Cprev, dC)

    dIinput = hp(sigmoidDer(@I), dI)
    dFinput = hp(sigmoidDer(@F), dF)
    dOinput = hp(sigmoidDer(@O), dO)
    dGinput = hp(tanhDer(@G), dG)
    ### Change required to input weight matrices (same sign as U and B)
    @dWi += hp(dIinput, @Xt)
    @dWf += hp(dFinput, @Xt)
    @dWo += hp(dOinput, @Xt)
    @dWg += hp(dGinput, @Xt)
    ### Change required to previous-output weight matrices
    @dUi += hp(dIinput, @Hprev)
    @dUf += hp(dFinput, @Hprev)
    @dUo += hp(dOinput, @Hprev)
    @dUg += hp(dGinput, @Hprev)
    ### Change required to bias vectors
    @dBi += dIinput
    @dBf += dFinput
    @dBo += dOinput
    @dBg += dGinput
    ### Delta for Hprev: Hprev enters the gates through the recurrent
    ### weights U, so its delta is propagated through U (not W).
    dHprev = DFloat.zeros(1, @sz)
    dHprev += @Ui.transpose.dot(dIinput)
    dHprev += @Uf.transpose.dot(dFinput)
    dHprev += @Uo.transpose.dot(dOinput)
    dHprev += @Ug.transpose.dot(dGinput)
    ### Bottom deltas handed to the previous cell
    @bottom_diff_c = hp(dC, @F)
    @bottom_diff_h = dHprev
  end

  # Applies the accumulated deltas scaled by the learning rate, then clears
  # the accumulators for the next pass.
  def applyWeightChange()
    ### Apply input weight changes
    @Wi += @alpha * @dWi
    @Wf += @alpha * @dWf
    @Wo += @alpha * @dWo
    @Wg += @alpha * @dWg
    ### Apply cell output weight changes
    @Ui += @alpha * @dUi
    @Uf += @alpha * @dUf
    @Uo += @alpha * @dUo
    @Ug += @alpha * @dUg
    ### Apply bias changes
    @Bi += @alpha * @dBi
    @Bf += @alpha * @dBf
    @Bo += @alpha * @dBo
    @Bg += @alpha * @dBg
    resetGradients
  end

  # Transfer functions (element-wise; any array shape is accepted)

  # Derivative of tanh written in terms of its output: 1 - tanh(x)^2.
  # @param v values that are already tanh outputs.
  def tanhDer(v)
    t = DFloat.cast(v)
    1.0 - t * t
  end

  # Derivative of the logistic sigmoid written in terms of its output: s * (1 - s).
  # @param v values that are already sigmoid outputs.
  def sigmoidDer(v)
    s = DFloat.cast(v)
    s * (1.0 - s)
  end

  # Logistic sigmoid applied to every element of v.
  def sigmoidVector(v)
    1.0 / (1.0 + NMath.exp(-DFloat.cast(v)))
  end

  # Logistic sigmoid of a single scalar.
  def sigmoid(value_)
    1 / (1 + Math.exp(-value_))
  end

  # Getters and Setters
  # Each setter requires a single-row array. Otherwise it prints the
  # offending shape and raises before anything is stored.
  def setXt(xt)
    checkRowVector(xt, "xt", "@Xt")
    @Xt = xt
  end

  def setYt(yt)
    checkRowVector(yt, "yt", "@Yt")
    @Yt = yt
  end

  def setHprev(hp)
    checkRowVector(hp, "hp", "@Hprev")
    @Hprev = hp
  end

  # Validates before assigning (previously an invalid value was stored and
  # then the error was raised).
  def setCprev(cp)
    checkRowVector(cp, "cp", "@Cprev")
    @Cprev = cp
  end

  def getBottomDeltaHt()
    @bottom_diff_h
  end

  def getBottomDeltaCt()
    @bottom_diff_c
  end

  def getYt()
    @Yt
  end

  def getHt()
    @Ht
  end

  def getCt
    @Ct
  end

  private

  # Zeroes every BPTT accumulator (input weights, recurrent weights, biases).
  def resetGradients
    ### Change related to cell input
    @dWi = DFloat.zeros(1, @sz)
    @dWf = DFloat.zeros(1, @sz)
    @dWo = DFloat.zeros(1, @sz)
    @dWg = DFloat.zeros(1, @sz)
    ### Change related to previous cell output
    @dUi = DFloat.zeros(1, @sz)
    @dUf = DFloat.zeros(1, @sz)
    @dUo = DFloat.zeros(1, @sz)
    @dUg = DFloat.zeros(1, @sz)
    ### Change required to bias
    @dBi = DFloat.zeros(1, @sz)
    @dBf = DFloat.zeros(1, @sz)
    @dBo = DFloat.zeros(1, @sz)
    @dBg = DFloat.zeros(1, @sz)
  end

  # Prints the shape and raises a RuntimeError unless arr has exactly one row.
  # @param label [String] name shown in the size message.
  # @param target [String] instance variable named in the error message.
  def checkRowVector(arr, label, target)
    return if arr.shape[0] == 1
    puts "Size of #{label}: " + arr.shape.to_s
    raise "Setting #{target} of LSTM cell raised an incorrect dimension error"
  end
end
data/lib/DICTIONARY.rb ADDED
@@ -0,0 +1,123 @@
1
+ #!/usr/bin/env ruby
2
+ =begin
3
+ Created: 14/09/2023.
4
+ Version: 1.0.0
5
+ =end
6
+ require 'numo/narray'
7
+ include Numo
8
+ =begin
9
+ Dictionary
10
+ =end
11
class DICTIONARY < ENCODER
  # Builds the word-level vocabulary.
  #
  # @param string [String, Array<String>] training text (split on whitespace) or a word list.
  # @param x_dim [Integer] encoding width; only the x_dim most frequent words receive a code.
  def init(string, x_dim)
    @x_dim = x_dim
    @frequencyHash = wordFrequency(string)
    @stringToEncoding, @encodingToString = hotEncodeVocabulary(@frequencyHash, x_dim)
  end

  # @return [Array<Array(String, Integer)>] [word, count] pairs, most frequent first
  #   (an Array of pairs, as produced by #wordFrequency).
  def getFrequencyHash()
    @frequencyHash
  end

  # @return [Array(Hash, Hash)] the word => code and code => word tables.
  def getEncodingHashs()
    [@stringToEncoding, @encodingToString]
  end

  # Counts how often each word occurs.
  #
  # @param words [String, Array<String>] a String is split on whitespace.
  # @return [Array<Array(String, Integer)>] [word, count] pairs sorted by count, highest first.
  def wordFrequency(words)
    words = words.split if words.is_a?(String)
    frequency = Hash.new(0)
    words.each { |word| frequency[word] += 1 }
    frequency.sort_by { |_word, count| count }.reverse
  end

  # Gives the first maxEntries words of `hash` a one-hot integer code: the
  # n-th word (0-based) gets 1 << n.
  #
  # @param hash [Enumerable] [word, count] pairs (Array of pairs or Hash).
  # @param maxEntries [Integer] maximum number of words to encode.
  # @return [Array(Hash, Hash)] word => code and code => word.
  def hotEncodeVocabulary(hash, maxEntries)
    stringToEncoding = {}
    encodingToString = {}
    # Clamped so a non-positive limit yields an empty vocabulary. The former
    # "binary == 1 << maxEntries" stop test never fired then and encoded every word.
    hash.first([maxEntries, 0].max).each_with_index do |(key, _count), position|
      code = 1 << position
      stringToEncoding[key.to_s] = code
      encodingToString[code] = key.to_s
    end
    [stringToEncoding, encodingToString]
  end

  # Splits one file into words and returns an input window and a target window.
  # The target window is the input window shifted by predict_len words.
  #
  # @return [Array(Array<String>, Array<String>)] input words and target words
  #   (either may be nil when the start index is beyond the end of the file).
  def readFileDataArray(fileDataArray, fileIndex, start, input_len, predict_len)
    fileWordArray = fileDataArray[fileIndex].split
    [fileWordArray[start, input_len], fileWordArray[start + predict_len, input_len]]
  end

  # Encodes a word sequence as an Int32 array of shape (words, @x_dim).
  # Each known word becomes the bit pattern of its code; unknown words stay all-zero.
  #
  # @param wordArray [String, Array<String>] a String is split on whitespace.
  # @param hash [Hash] word => integer code (defaults to the vocabulary built by #init).
  def encodeArray(wordArray, hash = @stringToEncoding)
    wordArray = wordArray.split if wordArray.is_a?(String)
    array = Int32.zeros(wordArray.length, @x_dim)
    wordArray.each_with_index do |word, row|
      encoding = hash[word]
      array[row, true] = binaryToArray(encoding, @x_dim) unless encoding.nil?
    end
    array
  end

  # Expands an integer code into exactly x_dim bits, most-significant first.
  # e.g. binaryToArray(1, 3) #=> [0, 0, 1]
  # @return [Array<Integer>]
  def binaryToArray(binary, x_dim)
    Array.new(x_dim) { |i| binary[x_dim - 1 - i] }
  end

  # Inverse of #binaryToArray: reads the entries as bits, most-significant first.
  # Entries are truncated to Integers first, so DFloat rows (1.0 / 0.0) decode
  # the same way as Int32 rows. Joining Float text ("1.0") previously broke the parse.
  def arrayToBinary(array)
    array.to_a.flatten.map(&:to_i).join.to_i(2)
  end

  # Decodes each row of encodedArray by its exact bit pattern.
  # @return [String] one "word " per row; rows with no matching word give "@ ".
  def decodeArray(encodedArray, hash = @encodingToString)
    words = (0...encodedArray.shape[0]).map do |row|
      decodedWord = hash[arrayToBinary(encodedArray[row, true])]
      decodedWord.nil? ? "@ " : decodedWord + " "
    end
    words.join
  end

  # Decodes each row by its largest element (ties go to the highest index).
  # @return [String] one "word " per row; rows whose maximum is 0 (or that are
  #   empty) give "$ "; positions with no vocabulary word give "@ ".
  def decodeArrayByMaximum(encodedArray, hash = @encodingToString)
    words = (0...encodedArray.shape[0]).map do |row|
      values = encodedArray[row, true].to_a
      next "$ " if values.empty? || values.max == 0
      _value, max_element = values.each_with_index.max
      # The one-hot bit pattern with a 1 at max_element, read most-significant first.
      decodedWord = hash[1 << (values.length - 1 - max_element)]
      decodedWord.nil? ? "@ " : decodedWord + " "
    end
    words.join
  end

  # Prints every key/value pair followed by the number of entries.
  def viewHash(hash)
    puts "key -> value"
    hash.each { |key, value| puts key.to_s + " -> " + value.to_s }
    puts "Entires in hash: " + hash.count.to_s
  end
end
data/lib/ENCODER.rb ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env ruby
2
+ =begin
3
+ ██╗░░░░░░██████╗████████╗███╗░░░███╗
4
+ ██║░░░░░██╔════╝╚══██╔══╝████╗░████║
5
+ ██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
6
+ ██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
7
+ ███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
8
+ ╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
9
+ Created: 14/09/2023.
10
+ Version: 1.0.0
11
+ Author: Ryan May
12
+ This is a text generator LSTM developed in Ruby.
13
+ Through this project I aim to better understand how LSTM networks are built and trained.
14
+ This LSTM network uses the Numo::NArray library to compute the cell states and to run the cell network.
15
+ Expansion on this project will be utilised in my Honours project.
16
+ =end
17
+ require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
18
+ require 'numo/narray'
19
+ include Numo
20
+ =begin
21
+ ███████╗███╗░░██╗░█████╗░░█████╗░██████╗░███████╗██████╗░
22
+ ██╔════╝████╗░██║██╔══██╗██╔══██╗██╔══██╗██╔════╝██╔══██╗
23
+ █████╗░░██╔██╗██║██║░░╚═╝██║░░██║██║░░██║█████╗░░██████╔╝
24
+ ██╔══╝░░██║╚████║██║░░██╗██║░░██║██║░░██║██╔══╝░░██╔══██╗
25
+ ███████╗██║░╚███║╚█████╔╝╚█████╔╝██████╔╝███████╗██║░░██║
26
+ ╚══════╝╚═╝░░╚══╝░╚════╝░░╚════╝░╚═════╝░╚══════╝╚═╝░░╚═╝
27
+ The encoder class translates input characters and sentences into one-hot encoded vectors.
28
+ Additionally, the encoder class contains functions to handle file reading, regex filtering, and conversions
29
+ between 'Matrix' and 'NArray' types.
30
+ =end
31
class ENCODER
  # Default alphabet: four punctuation/space symbols, then a-z and A-Z (56 symbols).
  DEFAULT_CHARACTERS = ['.', ',', "\s", '!', *('a'..'z'), *('A'..'Z')].freeze

  # Configures the character alphabet used for one-hot encoding.
  #
  # @param charmatrix [Array<String>, nil] custom alphabet; nil selects DEFAULT_CHARACTERS.
  def init(charmatrix = nil)
    @CharMatrix = charmatrix.nil? ? DEFAULT_CHARACTERS.dup : charmatrix
    @Length = @CharMatrix.length
    @SelectionMatrix = Matrix.build(1, @Length) { 0 } # 1 row, one column per character
  end

  # Converts a Ruby Matrix into a Numo::DFloat array with the same rows.
  def matrixToNArray(matrix)
    DFloat[*matrix.to_a]
  end

  # Converts a Numo array into a Ruby Matrix with the same rows.
  def nArrayToMatrix(nArray)
    Matrix[*nArray.to_a]
  end

  # Names (not paths) of the non-hidden, non-directory entries in path.
  # The directory test is resolved inside path. Previously it was checked
  # relative to the current working directory, so sub-directories of path slipped through.
  def listFilesInDirectory(path)
    Dir.entries(path).reject do |name|
      name.start_with?('.') || File.directory?(File.join(path, name))
    end
  end

  # Reads every file reported by #listFilesInDirectory, logging each one.
  # @return [Array<String>] file contents in listing order.
  def readPlaintextFilesInDirectory(path)
    listFilesInDirectory(path).map do |name|
      puts "reading file " + path + "/" + name
      readPlaintextfile(path + "/" + name)
    end
  end

  # One-hot encodes predict_len + input_len characters of the filtered file,
  # starting at `start`.
  # @return [Matrix]
  def readFileDataArray(fileDataArray, index, start, input_len, predict_len)
    hotEncodeSentance(filter(fileDataArray[index])[start, predict_len + input_len])
  end

  # Returns the whole file as a String (the handle is closed even on error).
  def readPlaintextfile(fileName)
    File.read(fileName)
  end

  # Deletes every match of regex from sentance. With no regex, everything
  # except letters, '.', ',' and space is removed.
  def filter(sentance, regex = nil)
    sentance.gsub(regex || /[^A-Za-z\., ]/, '')
  end

  # One-hot encodes a sentence, one Matrix row per character.
  # Characters outside the alphabet are skipped.
  # @return [Matrix] (encodable characters) x @Length; 0x0 for an empty result.
  def hotEncodeSentance(sentance)
    rows = sentance.chars.map { |letter| hotEncodeCharacter(letter) }.compact.map(&:to_a)
    # Built once: re-creating the Matrix per character was quadratic.
    Matrix.rows(rows)
  end

  # Decodes every row of a one-hot Matrix back into a String.
  def hotDecodeSentance(matrix)
    (0...matrix.row_count).map { |row| hotDecodeCharacter(matrix.row(row)) }.join
  end

  # One-hot encodes a single character over the alphabet.
  # @return [Vector, nil] Vector of length @Length with a 1 at the character's
  #   position, or nil when the character is not in the alphabet. The nil is
  #   the case #hotEncodeSentance skips; previously an unknown character raised.
  def hotEncodeCharacter(char)
    index = @CharMatrix.index(char)
    # Built directly rather than mutated with Matrix#[]=, which is private in
    # older versions of the matrix library.
    @SelectionMatrix = Matrix.build(1, @Length) { |_row, col| col == index ? 1 : 0 }
    return nil if index.nil?
    @SelectionMatrix.row(0)
  end

  # Decodes a (one-hot or score) vector to the character at its largest
  # element (ties go to the highest index). Returns '@' for an empty vector
  # or an index outside the alphabet.
  def hotDecodeCharacter(vector)
    best = vector.to_a.each_with_index.max
    return '@' if best.nil?
    @CharMatrix[best[1]] || '@'
  end

  # Fraction of character positions at which two strings differ, in [0.0, 1.0].
  # Positions are compared pairwise; every position beyond the shorter string
  # counts as a difference. Normalised by the longer length. Dividing by
  # a.size gave values above 1 or Infinity/NaN when a was shorter or empty.
  # Two empty strings are identical (0.0).
  def stringDifferencePercent(a, b)
    longer = [a.size, b.size].max
    return 0.0 if longer.zero?
    same = a.each_char.zip(b.each_char).count { |left, right| left == right }
    (longer - same) / longer.to_f
  end
end
data/lib/NETWORK.rb ADDED
@@ -0,0 +1,163 @@
1
+ #!/usr/bin/env ruby
2
+ =begin
3
+ ██╗░░░░░░██████╗████████╗███╗░░░███╗
4
+ ██║░░░░░██╔════╝╚══██╔══╝████╗░████║
5
+ ██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
6
+ ██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
7
+ ███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
8
+ ╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
9
+ Created: 14/09/2023.
10
+ Version: 1.0.0
11
+ Author: Ryan May
12
+ This is a text generator LSTM developed in Ruby.
13
+ Through this project I aim to better understand how LSTM networks are built and trained.
14
+ This LSTM network uses the Numo::NArray library to compute the cell states and to run the cell network.
15
+ Expansion on this project will be utilised in my Honours project.
16
+ =end
17
+ require_relative 'ENCODER'
18
+ require_relative 'CELL'
19
+ require_relative 'DICTIONARY'
20
+ require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
21
+ require 'numo/narray'
22
+ include Numo
23
+ =begin
24
+ ██╗░░░░░░██████╗████████╗███╗░░░███╗ ███╗░░██╗███████╗████████╗░██╗░░░░░░░██╗░█████╗░██████╗░██╗░░██╗
25
+ ██║░░░░░██╔════╝╚══██╔══╝████╗░████║ ████╗░██║██╔════╝╚══██╔══╝░██║░░██╗░░██║██╔══██╗██╔══██╗██║░██╔╝
26
+ ██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║ ██╔██╗██║█████╗░░░░░██║░░░░╚██╗████╗██╔╝██║░░██║██████╔╝█████═╝░
27
+ ██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║ ██║╚████║██╔══╝░░░░░██║░░░░░████╔═████║░██║░░██║██╔══██╗██╔═██╗░
28
+ ███████╗██████╔╝░░░██║░░░██║░╚═╝░██║ ██║░╚███║███████╗░░░██║░░░░░╚██╔╝░╚██╔╝░╚█████╔╝██║░░██║██║░╚██╗
29
+ ╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝ ╚═╝░░╚══╝╚══════╝░░░╚═╝░░░░░░╚═╝░░░╚═╝░░░╚════╝░╚═╝░░╚═╝╚═╝░░╚═╝
30
+ This class implements the single-layer LSTM network. The LSTM network consists of an array of LSTM cell objects,
31
+ an input matrix, and a target matrix.
32
+ Network-level forward and backward propagation are implemented here.
33
+ This is separate from the cell-level forward and backward propagation in the LSTM_CELL class.
34
+ =end
35
class LSTM_NETWORK
  # Builds a single-layer network of `nodes` LSTM cells, one per time step.
  #
  # @param nodes [Integer] number of cells / time steps.
  # @param x_dim [Integer] width of every input, target and hidden vector.
  # @param alpha [Numeric] learning rate handed to every cell.
  # @param terminal_output [Object] forwarded unchanged to LSTM_CELL#init.
  def init(nodes, x_dim, alpha, terminal_output = nil)
    @dict = DICTIONARY.new
    @encoder = ENCODER.new
    @encoder.init
    @sz = x_dim
    # Instance variable (was @@Alpha) so separate networks keep separate rates.
    @alpha = alpha
    @Nodes = nodes
    ### Creating LSTM network (1 x nodes matrix of cells)
    # Matrix.build visits (0, 0), (0, 1), ... in order, so cells are created
    # and initialised left to right without mutating the Matrix afterwards.
    @lstm_nodes = Matrix.build(1, @Nodes) do
      cell = LSTM_CELL.new
      cell.init(@alpha, @sz, terminal_output)
      cell
    end
    ## Input and target
    @Input = DFloat.zeros(nodes, @sz)
    @Target = DFloat.zeros(nodes, @sz)
  end

  # @return [Matrix] the 1 x nodes matrix of LSTM_CELL objects.
  def getLSTMNodes()
    @lstm_nodes
  end

  # Replaces the dictionary used for word-level encoding and decoding.
  def setDictionary(dictionary)
    @dict = dictionary
  end

  # Sets the target sequence.
  # @param target [String, Array<String>, Object] a String is char-encoded,
  #   an Array is word-encoded, and anything is taken as-is when mode == "encoded".
  # @raise [RuntimeError] for any other type.
  def setTarget(target, mode = nil)
    @Target = encodeSequence(target, mode, "Target input is not of type Array or String")
  end

  # Sets the input sequence; same rules as #setTarget.
  # @raise [RuntimeError] for an unsupported type.
  def setInput(input, mode = nil)
    @Input = encodeSequence(input, mode, "Input is not of type Array or String")
  end

  # @param mode ["word_mode", "char_mode", "array_mode"]; any other value returns nil.
  def getTarget(mode = "word_mode")
    decodeSequence(@Target, mode, :decodeArray)
  end

  # @param mode ["word_mode", "char_mode", "array_mode"]; any other value returns nil.
  def getInput(mode = "word_mode")
    decodeSequence(@Input, mode, :decodeArray)
  end

  # Word mode decodes each output row by its largest element.
  # @param mode ["word_mode", "char_mode", "array_mode"]; any other value returns nil.
  def getOutput(mode = "word_mode")
    decodeSequence(@Output, mode, :decodeArrayByMaximum)
  end

  # Runs every cell left to right. Each cell receives its Input/Target row
  # and the previous cell's output and state. The first cell receives
  # initialH / initialC, or zeros when either of them is nil. Ht of every
  # cell is collected row by row into @Output.
  def forwardPropagate(initialH = DFloat.zeros(1, @sz), initialC = DFloat.zeros(1, @sz))
    @Output = DFloat.zeros(@Nodes, @sz)
    if initialH.nil? || initialC.nil?
      initialH = DFloat.zeros(1, @sz)
      initialC = DFloat.zeros(1, @sz)
    end
    previous = nil
    @lstm_nodes.column_count.times do |i|
      cell = @lstm_nodes[0, i]
      if previous.nil?
        cell.setHprev(initialH)
        cell.setCprev(initialC)
      else
        cell.setHprev(previous.getHt)
        cell.setCprev(previous.getCt)
      end
      # Each row is copied into a (1, width) DFloat, as the cell setters require.
      cell.setXt(DFloat[@Input[i, true].to_a])
      cell.setYt(DFloat[@Target[i, true].to_a])
      cell.forwardPropagation
      @Output[i, true] = cell.getHt
      previous = cell
    end
  end

  # Backpropagation through time. The walk goes from the last cell to the
  # first so that each cell's bottom deltas reach the cell that produced
  # those values: a cell's bottom deltas are the gradients for its Hprev and
  # Cprev inputs, i.e. for its predecessor's Ht and Ct. The previous
  # first-to-last walk handed them to the successor instead.
  def backwardPropagate()
    last = @lstm_nodes.column_count - 1
    return if last < 0
    @lstm_nodes[0, last].backwardPropagation(outputDelta(@lstm_nodes[0, last]), DFloat.zeros(1, @sz))
    (last - 1).downto(0) do |i|
      successor = @lstm_nodes[0, i + 1]
      delta_h = outputDelta(@lstm_nodes[0, i]) + successor.getBottomDeltaHt
      delta_c = successor.getBottomDeltaCt
      @lstm_nodes[0, i].backwardPropagation(delta_h, delta_c)
    end
  end

  # Applies and clears the accumulated gradients of every cell, left to right.
  def applyWeightChange()
    @lstm_nodes.each(&:applyWeightChange)
  end

  private

  # Shared conversion for #setInput / #setTarget.
  def encodeSequence(value, mode, error_message)
    return value if mode == "encoded"
    case value
    when String then @encoder.matrixToNArray(@encoder.hotEncodeSentance(value))
    when Array then @dict.encodeArray(value)
    else raise error_message
    end
  end

  # Shared decoding for the get* readers. word_method names the DICTIONARY
  # decoder used in word mode.
  def decodeSequence(array, mode, word_method)
    case mode
    when "word_mode" then @dict.public_send(word_method, array)
    when "char_mode" then @encoder.hotDecodeSentance(@encoder.nArrayToMatrix(array))
    when "array_mode" then array
    end
  end

  # Output-error term 2 * (Yt - Ht) of one cell, the negative gradient of
  # its squared error.
  def outputDelta(cell)
    2 * (cell.getYt - cell.getHt)
  end
end
metadata ADDED
@@ -0,0 +1,46 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: basic-lstm
3
+ version: !ruby/object:Gem::Version
4
+ version: 1.0.0
5
+ platform: ruby
6
+ authors:
7
+ - Ryan May
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2023-06-30 00:00:00.000000000 Z
12
+ dependencies: []
13
+ description:
14
+ email: 19477774@student.curtin.edu.au
15
+ executables: []
16
+ extensions: []
17
+ extra_rdoc_files: []
18
+ files:
19
+ - lib/CELL.rb
20
+ - lib/DICTIONARY.rb
21
+ - lib/ENCODER.rb
22
+ - lib/NETWORK.rb
23
+ homepage: https://github.com/ryan-n-may/Ruby_LSTM
24
+ licenses:
25
+ - MIT
26
+ metadata: {}
27
+ post_install_message:
28
+ rdoc_options: []
29
+ require_paths:
30
+ - lib
31
+ required_ruby_version: !ruby/object:Gem::Requirement
32
+ requirements:
33
+ - - ">="
34
+ - !ruby/object:Gem::Version
35
+ version: '0'
36
+ required_rubygems_version: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ requirements: []
42
+ rubygems_version: 3.3.15
43
+ signing_key:
44
+ specification_version: 4
45
+ summary: A basic LSTM package
46
+ test_files: []