basic-lstm 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/lib/CELL.rb +261 -0
- data/lib/DICTIONARY.rb +123 -0
- data/lib/ENCODER.rb +132 -0
- data/lib/NETWORK.rb +163 -0
- metadata +46 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 77cfe659ce2d038f9f415fee83f260f7b7923933fa9275536f490066f26c96af
|
4
|
+
data.tar.gz: d5f35dd0a7625c278f9c17230e5fa7b31f8495405d22302786ef7f831877c692
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 2c854837f63551d21cde6a45d57855cda75a1b8b8e4db14d51ce540bbc24e127995793046385c56e9e05a9adaf12b7b2fce1fc05994d90007f39409c9056dd65
|
7
|
+
data.tar.gz: 57d4f463984171afd4a90d267d430b5226c649edcc08104634c4ab803d46cf8fd5c2428653c888ad3ab6d8724eba646bfec4c118c3bea2e0f64a240ff673db4f
|
data/lib/CELL.rb
ADDED
@@ -0,0 +1,261 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
=begin
|
3
|
+
██╗░░░░░░██████╗████████╗███╗░░░███╗
|
4
|
+
██║░░░░░██╔════╝╚══██╔══╝████╗░████║
|
5
|
+
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
|
6
|
+
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
|
7
|
+
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
|
8
|
+
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
|
9
|
+
Created: 14/09/2023.
|
10
|
+
Version: 1.0.0
|
11
|
+
Author: Ryan May
|
12
|
+
This is a text generator LSTM developed in Ruby.
|
13
|
+
In this project I gain more of an understanding of the training and creation of LSTM networks.
|
14
|
+
This LSTM network utilises the Numo::Narray library to calculate cell states and the cell network.
|
15
|
+
Expansion on this project will be utilised in my Honours project.
|
16
|
+
=end
|
17
|
+
require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
|
18
|
+
require 'numo/narray'
|
19
|
+
include Numo
|
20
|
+
=begin
|
21
|
+
██╗░░░░░░██████╗████████╗███╗░░░███╗ ░█████╗░███████╗██╗░░░░░██╗░░░░░
|
22
|
+
██║░░░░░██╔════╝╚══██╔══╝████╗░████║ ██╔══██╗██╔════╝██║░░░░░██║░░░░░
|
23
|
+
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║ ██║░░╚═╝█████╗░░██║░░░░░██║░░░░░
|
24
|
+
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║ ██║░░██╗██╔══╝░░██║░░░░░██║░░░░░
|
25
|
+
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║ ╚█████╔╝███████╗███████╗███████╗
|
26
|
+
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝ ░╚════╝░╚══════╝╚══════╝╚══════╝
|
27
|
+
=end
|
28
|
+
class LSTM_CELL
|
29
|
+
# Constructor
|
30
|
+
# This initialisation function configures the LSTM Cell by
|
31
|
+
# constructing the required arrays (Numo::NArrays)
|
32
|
+
# Configures the cell: stores the vector width and learning rate, then
# allocates every weight, bias, state, gate and gradient array as a
# 1 x sz Numo::DFloat.
#
# alpha           - learning rate; kept in the class variable @@Alpha, so the
#                   most recent call sets it for every LSTM_CELL instance.
# sz              - number of columns in each array.
# terminal_output - accepted but not used by this method.
def init(alpha, sz, terminal_output=nil)
  @sz = sz
  @@Alpha = alpha
  random_row = -> { DFloat.new(1, @sz).rand }
  zero_row = -> { DFloat.zeros(1, @sz) }

  # Input weights (W*), previous-output weights (U*) and biases (B*) start random.
  %w[Wg Wi Wf Wo Ug Ui Uf Uo Bg Bi Bf Bo].each do |name|
    instance_variable_set("@#{name}", random_row.call)
  end

  # Cell input/target, previous state, current state and output start at zero.
  %w[Xt Yt Hprev Cprev Ct Ht].each do |name|
    instance_variable_set("@#{name}", zero_row.call)
  end

  # Gate activations start random; forwardPropagation overwrites them.
  %w[F I O G].each do |name|
    instance_variable_set("@#{name}", random_row.call)
  end

  # Accumulated weight/bias changes used by back propagation start at zero.
  %w[dWi dWf dWo dWg dUi dUf dUo dUg dBi dBf dBo dBg].each do |name|
    instance_variable_set("@#{name}", zero_row.call)
  end

  # The last array allocated is also the method's return value.
  @dBg
end
|
80
|
+
# Hadamard product
|
81
|
+
# Element-wise (Hadamard) product of two 2-D arrays.
#
# The result is a new DFloat with the shape of nArr1. Only the entries of
# nArr2 at the row/column positions that exist in nArr1 are read, so nArr2
# may have more rows or columns than nArr1.
def hp(nArr1, nArr2)
  product = DFloat.zeros(nArr1.shape())
  row_count, col_count = nArr1.shape()
  (0...row_count).each do |r|
    (0...col_count).each do |c|
      product[r, c] = (nArr1[r, c] * nArr2[r, c])
    end
  end
  product
end
|
91
|
+
# Forward Propagation
|
92
|
+
# Runs one forward step of the cell using the current @Xt, @Hprev and @Cprev.
#
# Stores the pre-activations (@Zf, @Zi, @Zo, @Zg), the gate activations
# (@F, @I, @O, @G), the new cell state @Ct and the new output @Ht.
# Returns @Ht.
def forwardPropagation()
  # Every gate shares the same pre-activation form: W.dot(x) + U.dot(h_prev) + B.
  pre_activation = lambda do |w, u, b|
    w.dot(@Xt) + u.dot(@Hprev) + b
  end

  ### Gate values
  @Zf = pre_activation.call(@Wf, @Uf, @Bf)
  @F = sigmoidVector(@Zf)
  @Zi = pre_activation.call(@Wi, @Ui, @Bi)
  @I = sigmoidVector(@Zi)
  @Zo = pre_activation.call(@Wo, @Uo, @Bo)
  @O = sigmoidVector(@Zo)
  @Zg = pre_activation.call(@Wg, @Ug, @Bg)
  @G = NMath.tanh(@Zg)

  ### Cell state and cell output
  @Ct = hp(@F, @Cprev) + hp(@I, @G)
  @Ht = hp(@O, NMath.tanh(@Ct))
end
|
106
|
+
# Back propagation
|
107
|
+
# Runs one backward step of the cell.
#
# top_diff_h - difference signal for this cell's output.
# top_diff_c - difference signal for this cell's state.
#
# Accumulates into @dW*, @dU* and @dB*, stores @bottom_diff_c and
# @bottom_diff_h for the neighbouring cell, and returns @bottom_diff_h.
def backwardPropagation(top_diff_h, top_diff_c)
  state_diff = hp(@O, top_diff_h) + top_diff_c

  # Differences with respect to each gate activation.
  # NOTE(review): forwardPropagation computes Ht = O * tanh(Ct), yet the output
  # gate term below multiplies by @Ct, not tanh(@Ct) - confirm this is intended.
  out_diff = hp(@Ct, top_diff_h)
  in_diff = hp(@G, state_diff)
  cand_diff = hp(@I, state_diff)
  forget_diff = hp(@Cprev, state_diff)

  # Differences with respect to each gate pre-activation.
  in_pre = hp(sigmoidDer(@I), in_diff)
  forget_pre = hp(sigmoidDer(@F), forget_diff)
  out_pre = hp(sigmoidDer(@O), out_diff)
  cand_pre = hp(tanhDer(@G), cand_diff)

  ### Change required to the input weight matrices
  # NOTE(review): these are subtracted while the @dU*/@dB* terms below are
  # added, and applyWeightChange adds all of them - verify the sign.
  @dWi = @dWi - hp(in_pre, @Xt)
  @dWf = @dWf - hp(forget_pre, @Xt)
  @dWo = @dWo - hp(out_pre, @Xt)
  @dWg = @dWg - hp(cand_pre, @Xt)

  ### Change required to the previous-output weight matrices
  @dUi = @dUi + hp(in_pre, @Hprev)
  @dUf = @dUf + hp(forget_pre, @Hprev)
  @dUo = @dUo + hp(out_pre, @Hprev)
  @dUg = @dUg + hp(cand_pre, @Hprev)

  ### Change required to the bias vectors
  @dBi = @dBi + in_pre
  @dBf = @dBf + forget_pre
  @dBo = @dBo + out_pre
  @dBg = @dBg + cand_pre

  ### Change required to Hprev
  # NOTE(review): uses the input weights @W*, although @Hprev enters the
  # forward pass through @U* - confirm against the intended derivation.
  weighted = [[@Wi, in_pre], [@Wf, forget_pre], [@Wo, out_pre], [@Wg, cand_pre]]
  prev_out_diff = weighted.inject(DFloat.zeros(1, @sz)) do |sum, (weights, pre)|
    sum + weights.transpose.dot(pre)
  end

  ### Values handed to the neighbouring cell
  @bottom_diff_c = hp(state_diff, @F)
  @bottom_diff_h = prev_out_diff
end
|
144
|
+
# Applies the accumulated changes to every weight and bias array
# (value += @@Alpha * change) and then resets every accumulator to zeros.
# Returns the freshly zeroed @dBg.
def applyWeightChange()
  # Each name N has a parameter @N and an accumulator @dN.
  parameters = %w[Wi Wf Wo Wg Ui Uf Uo Ug Bi Bf Bo Bg]

  parameters.each do |name|
    current = instance_variable_get("@#{name}")
    change = instance_variable_get("@d#{name}")
    instance_variable_set("@#{name}", current + @@Alpha * change)
  end

  parameters.each do |name|
    instance_variable_set("@d#{name}", DFloat.zeros(1, @sz))
  end

  @dBg
end
|
178
|
+
# Transfer functions
|
179
|
+
# Returns a new DFloat holding 1 - v[i,j]**2 for every element of the 2-D
# array v. In this file it is applied to @G, which already holds tanh outputs.
def tanhDer(v)
  rows, cols = v.shape[0], v.shape[1]
  derivative = DFloat.zeros(rows, cols)
  rows.times do |r|
    cols.times do |c|
      derivative[r, c] = 1 - (v[r, c]**2)
    end
  end
  derivative
end
|
189
|
+
# Returns a new DFloat holding v[i,j] * (1 - v[i,j]) for every element of the
# 2-D array v. In this file it is applied to gate activations that already
# hold sigmoid outputs.
def sigmoidDer(v)
  rows, cols = v.shape[0], v.shape[1]
  derivative = DFloat.zeros(rows, cols)
  rows.times do |r|
    cols.times do |c|
      derivative[r, c] = v[r, c] * (1 - v[r, c])
    end
  end
  derivative
end
|
199
|
+
# Applies the scalar sigmoid to every element of the 2-D array v and returns
# the results in a new DFloat of the same two dimensions.
def sigmoidVector(v)
  rows, cols = v.shape[0], v.shape[1]
  activated = DFloat.zeros(rows, cols)
  rows.times do |r|
    cols.times do |c|
      activated[r, c] = sigmoid(v[r, c])
    end
  end
  activated
end
|
209
|
+
# Logistic sigmoid of a single number: 1 / (1 + e^(-value_)).
def sigmoid(value_)
  1 / (1 + Math.exp(-value_))
end
|
213
|
+
# Getters and Setters
|
214
|
+
# Sets the cell input @Xt.
#
# xt - array whose shape has 1 as its first dimension (a single row).
#
# Prints the offending shape and raises RuntimeError when the first dimension
# is not 1; @Xt is left untouched in that case. Returns xt on success.
def setXt(xt)
  unless xt.shape()[0] == 1
    puts "Size of xt: " + xt.shape().to_s
    raise "Setting @Xt of LSTM cell raised an incorrect dimension error"
  end
  @Xt = xt
end
|
222
|
+
# Sets the cell target @Yt.
#
# yt - array whose shape has 1 as its first dimension (a single row).
#
# Prints the offending shape and raises RuntimeError when the first dimension
# is not 1; @Yt is left untouched in that case. Returns yt on success.
def setYt(yt)
  unless yt.shape()[0] == 1
    puts "Size of yt: " + yt.shape().to_s
    raise "Setting @Yt of LSTM cell raised an incorrect dimension error"
  end
  @Yt = yt
end
|
230
|
+
# Sets the previous cell output @Hprev.
#
# hp - array whose shape has 1 as its first dimension (a single row).
#      Inside this method the parameter shadows the hp helper method.
#
# Prints the offending shape and raises RuntimeError when the first dimension
# is not 1; @Hprev is left untouched in that case. Returns hp on success.
def setHprev(hp)
  unless hp.shape()[0] == 1
    puts "Size of hp: " + hp.shape().to_s
    raise "Setting @Hprev of LSTM cell raised an incorrect dimension error"
  end
  @Hprev = hp
end
|
238
|
+
# Sets the previous cell state @Cprev.
#
# cp - array whose shape has 1 as its first dimension (a single row).
#
# The shape is validated before anything is stored, so a rejected value never
# replaces the current @Cprev (matching setXt, setYt and setHprev).
# Prints the offending shape and raises RuntimeError when the first dimension
# is not 1. Returns cp on success.
def setCprev(cp)
  unless cp.shape()[0] == 1
    puts "Size of cp: " + cp.shape().to_s
    raise "Setting @Cprev of LSTM cell raised an incorrect dimension error"
  end
  @Cprev = cp
end
|
246
|
+
# Returns @bottom_diff_h, the value backwardPropagation stored for the
# neighbouring cell (nil until backwardPropagation has run).
def getBottomDeltaHt
  @bottom_diff_h
end
|
249
|
+
# Returns @bottom_diff_c, the cell-state value backwardPropagation stored for
# the neighbouring cell (nil until backwardPropagation has run).
def getBottomDeltaCt
  @bottom_diff_c
end
|
252
|
+
# Returns the target row currently stored in @Yt.
def getYt
  @Yt
end
|
255
|
+
# Returns the cell output currently stored in @Ht.
def getHt
  @Ht
end
|
258
|
+
# Returns the cell state currently stored in @Ct.
def getCt()
  @Ct
end
|
261
|
+
end
|
data/lib/DICTIONARY.rb
ADDED
@@ -0,0 +1,123 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
=begin
|
3
|
+
Created: 14/09/2023.
|
4
|
+
Version: 1.0.0
|
5
|
+
=end
|
6
|
+
require 'numo/narray'
|
7
|
+
include Numo
|
8
|
+
=begin
|
9
|
+
Dictionary
|
10
|
+
=end
|
11
|
+
class DICTIONARY < ENCODER
|
12
|
+
# Builds the dictionary from a body of text.
#
# string - text (or word array) handed to wordFrequency.
# x_dim  - width of each encoded word; also the maximum vocabulary size
#          passed to hotEncodeVocabulary.
#
# Stores the frequency list and both lookup hashes, and returns the pair
# [word => code, code => word].
def init(string, x_dim)
  @x_dim = x_dim
  @frequencyHash = wordFrequency(string)
  lookups = hotEncodeVocabulary(@frequencyHash, x_dim)
  @stringToEncoding, @encodingToString = lookups
end
|
17
|
+
# Returns the word-frequency data stored by init.
def getFrequencyHash
  @frequencyHash
end
|
20
|
+
# Returns both lookup tables as a two-element array:
# [word => code, code => word].
def getEncodingHashs
  [@stringToEncoding, @encodingToString]
end
|
23
|
+
# Counts how often each word occurs.
#
# words - a String (split on whitespace) or an array of words.
#
# Returns an array of [word, count] pairs ordered from most to least frequent.
# The relative order of words with equal counts is whatever sort_by yields
# (Ruby does not guarantee a stable sort).
def wordFrequency(words)
  words = words.split if words.is_a?(String)
  counts = Hash.new(0)
  words.each { |word| counts[word] += 1 }
  counts.sort_by { |_word, count| count }.reverse
end
|
39
|
+
# Assigns a one-hot integer code to each of the first maxEntries words.
#
# hash       - enumerable of [word, count] pairs (or a Hash), already in the
#              order the codes should be handed out.
# maxEntries - maximum number of words to encode; zero or a negative value
#              yields empty tables.
#
# The n-th word (counting from 0) receives the code 1 << n.
# Returns two hashes: word => code and code => word.
def hotEncodeVocabulary(hash, maxEntries)
  stringToEncoding = {}
  encodingToString = {}
  limit = [maxEntries, 0].max
  hash.first(limit).each_with_index do |(key, _count), position|
    word = "#{key}"
    code = 0b1 << position
    stringToEncoding[word] = code
    encodingToString[code] = word
  end
  return stringToEncoding, encodingToString
end
|
53
|
+
# Cuts an input window and a target window out of one stored text.
#
# fileDataArray - array of text strings.
# fileIndex     - which text to use.
# start         - index of the first input word.
# input_len     - number of words in each window.
# predict_len   - how many words the target window is shifted past start.
#
# Returns [input_words, target_words].
def readFileDataArray(fileDataArray, fileIndex, start, input_len, predict_len)
  words = fileDataArray[fileIndex].split
  input_window = words[start, input_len]
  target_window = words[start + predict_len, input_len]
  [input_window, target_window]
end
|
57
|
+
# Encodes a sequence of words as rows of an Int32 array.
#
# wordArray - array of words, or a String that is split on whitespace.
# hash      - word => code lookup (defaults to the dictionary's own table).
#
# Returns an Int32 array with one row per word and @x_dim columns. Words
# missing from the lookup leave their row as zeros.
def encodeArray(wordArray, hash=@stringToEncoding)
  words = wordArray.instance_of?(String) ? wordArray.split() : wordArray
  encoded = Int32.zeros(words.length(), @x_dim)
  words.length().times do |row|
    code = hash[words[row]]
    next if code == nil
    encoded[row, true] = binaryToArray(code, @x_dim)
  end
  encoded
end
|
72
|
+
# Expands an integer code into its x_dim lowest bits, most significant first.
#
# binary - non-negative integer code (e.g. a one-hot value 1 << n).
# x_dim  - number of bits to emit.
#
# Returns a plain Ruby Array of x_dim Integers, each 0 or 1. Bits above
# position x_dim - 1 are ignored rather than corrupting the output.
def binaryToArray(binary, x_dim)
  (x_dim - 1).downto(0).map { |bit| binary[bit] }
end
|
78
|
+
# Packs an array of bits (most significant first) into an integer.
#
# array - anything that responds to to_a; nested rows are flattened. Each
#         element is truncated to an integer first, so float bits such as
#         0.0 / 1.0 are read the same way as 0 / 1.
#
# Returns the integer whose binary digits are the given bits (0 when empty).
def arrayToBinary(array)
  bits = array.to_a.flatten.map { |bit| bit.to_i }
  bits.join("").to_i(2)
end
|
83
|
+
# Decodes an encoded array back into text, one word per row.
#
# encodedArray - 2-D array whose rows are bit patterns.
# hash         - code => word lookup (defaults to the dictionary's own table).
#
# Returns a String in which every word is followed by a space; rows whose
# code is not in the lookup are written as "@ ".
def decodeArray(encodedArray, hash=@encodingToString)
  row_total = encodedArray.shape()[0]
  (0...row_total).reduce("") do |text, row|
    word = hash[arrayToBinary(encodedArray[row, true])]
    text + (word != nil ? word + " " : "@ ")
  end
end
|
95
|
+
# Decodes an array of scores into text by taking the strongest column of
# each row as the chosen bit.
#
# encodedArray - 2-D array with one row of scores per word.
# hash         - code => word lookup (defaults to the dictionary's own table).
#
# For each row: if its maximum equals 0 the row is written as "$ "; otherwise
# a one-hot row is built at the position of the maximum (the last such
# position when several are equal) and looked up. Codes missing from the
# lookup are written as "@ ". Every token is followed by a space.
def decodeArrayByMaximum(encodedArray, hash=@encodingToString)
  text = ""
  encodedArray.shape()[0].times do |row|
    scores = encodedArray[row, true]
    values = scores.to_a
    if values.max == 0
      text += "$ "
      next
    end
    strongest = values.each_with_index.max[1]
    one_hot = Int32.zeros(scores.shape())
    one_hot[strongest] = 1
    word = hash[arrayToBinary(one_hot)]
    text += (word != nil ? word + " " : "@ ")
  end
  text
end
|
114
|
+
# Prints every key/value pair of hash to standard output, framed by a header
# line and a final line with the number of pairs printed.
def viewHash(hash)
  puts "key -> value"
  printed = 0
  hash.each do |key, value|
    puts "#{key} -> #{value}"
    printed += 1
  end
  puts "Entires in hash: #{printed}"
end
|
123
|
+
end
|
data/lib/ENCODER.rb
ADDED
@@ -0,0 +1,132 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
=begin
|
3
|
+
██╗░░░░░░██████╗████████╗███╗░░░███╗
|
4
|
+
██║░░░░░██╔════╝╚══██╔══╝████╗░████║
|
5
|
+
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
|
6
|
+
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
|
7
|
+
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
|
8
|
+
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
|
9
|
+
Created: 14/09/2023.
|
10
|
+
Version: 1.0.0
|
11
|
+
Author: Ryan May
|
12
|
+
This is a text generator LSTM developed in Ruby.
|
13
|
+
In this project I gain more of an understanding of the training and creation of LSTM networks.
|
14
|
+
This LSTM network utilises the Numo::Narray library to calculate cell states and the cell network.
|
15
|
+
Expansion on this project will be utilised in my Honours project.
|
16
|
+
=end
|
17
|
+
require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
|
18
|
+
require 'numo/narray'
|
19
|
+
include Numo
|
20
|
+
=begin
|
21
|
+
███████╗███╗░░██╗░█████╗░░█████╗░██████╗░███████╗██████╗░
|
22
|
+
██╔════╝████╗░██║██╔══██╗██╔══██╗██╔══██╗██╔════╝██╔══██╗
|
23
|
+
█████╗░░██╔██╗██║██║░░╚═╝██║░░██║██║░░██║█████╗░░██████╔╝
|
24
|
+
██╔══╝░░██║╚████║██║░░██╗██║░░██║██║░░██║██╔══╝░░██╔══██╗
|
25
|
+
███████╗██║░╚███║╚█████╔╝╚█████╔╝██████╔╝███████╗██║░░██║
|
26
|
+
╚══════╝╚═╝░░╚══╝░╚════╝░░╚════╝░╚═════╝░╚══════╝╚═╝░░╚═╝
|
27
|
+
The encoder class is used to translate input characters and sentences into one-hot encoded vectors.
|
28
|
+
Additionally, the encoder class contains functions to handle file reading, regex filtering, and conversions
|
29
|
+
between 'Matrix' and 'NArray' types.
|
30
|
+
=end
|
31
|
+
class ENCODER
|
32
|
+
# Sets up the character alphabet used for one-hot encoding.
#
# charmatrix - optional array of characters. When nil, the default 56-entry
#              alphabet is used: '.', ',', space, '!', then a-z, then A-Z.
#
# Stores the alphabet in @CharMatrix, its size in @Length and a 1 x @Length
# zero Matrix in @SelectionMatrix. Returns the alphabet.
def init(charmatrix = nil)
  alphabet = charmatrix
  if alphabet == nil
    alphabet = ['.', ',', "\s", '!'] + ('a'..'z').to_a + ('A'..'Z').to_a
  end
  @Length = alphabet.length()
  @SelectionMatrix = Matrix.build(1, @Length) { 0 }
  @CharMatrix = alphabet
end
|
45
|
+
# Converts a Matrix into a Numo::DFloat built from the matrix's rows.
def matrixToNArray(matrix)
  rows = matrix.to_a
  DFloat[*rows]
end
|
49
|
+
# Converts a 2-D array (anything whose to_a yields an array of rows) into a
# Matrix with the same rows.
def nArrayToMatrix(nArray)
  # copy = false: wrap the freshly produced row arrays directly, as Matrix[] does.
  Matrix.rows(nArray.to_a, false)
end
|
55
|
+
# Lists the names of the visible, non-directory entries of a directory.
#
# path - directory to list.
#
# Entries whose name starts with '.' (including "." and "..") and entries
# that are directories inside path are left out. The directory test is made
# against the entry's location inside path, not the current working
# directory. Returns an array of bare file names in Dir.entries order.
def listFilesInDirectory(path)
  Dir.entries(path).reject do |entry|
    entry.start_with?('.') || File.directory?(File.join(path, entry))
  end
end
|
59
|
+
# Reads every visible, non-directory entry of a directory.
#
# path - directory to read from.
#
# Entries whose name starts with '.' and entries that are directories inside
# path are skipped (the directory test uses the entry's location inside path,
# so sub-directories are never passed to readPlaintextfile). Prints a
# "reading file ..." line per file. Returns an array with one String of
# contents per file, in Dir.entries order.
def readPlaintextFilesInDirectory(path)
  file_names = Dir.entries(path).reject do |entry|
    entry.start_with?('.') || File.directory?(File.join(path, entry))
  end
  file_names.map do |name|
    full_path = path + "/" + name
    puts "reading file " + full_path
    readPlaintextfile(full_path)
  end
end
|
71
|
+
# One-hot encodes a window of characters taken from one stored text.
#
# fileDataArray - array of text strings.
# index         - which text to use.
# start         - offset of the first character, counted after filtering.
# input_len, predict_len - the window is input_len + predict_len characters.
#
# Returns the Matrix produced by hotEncodeSentance for that window.
def readFileDataArray(fileDataArray, index, start, input_len, predict_len)
  cleaned = filter(fileDataArray[index])
  window = cleaned[start, predict_len + input_len]
  hotEncodeSentance(window)
end
|
75
|
+
# Returns the entire contents of the file at fileName as a String.
#
# File.read opens, reads and closes the file in one call, so the handle is
# released even if reading raises.
def readPlaintextfile(fileName)
  File.read(fileName)
end
|
81
|
+
# Removes unwanted characters from a string.
#
# sentance - text to clean.
# regex    - pattern of text to delete. When nil, everything except ASCII
#            letters, '.', ',' and space is deleted.
#
# Returns a new String; the argument is not modified.
def filter(sentance, regex=nil)
  pattern = regex == nil ? /[^A-Za-z\., ]/ : regex
  sentance.gsub(pattern, '')
end
|
89
|
+
# One-hot encodes a string, one Matrix row per character.
#
# sentance - text to encode.
#
# Characters for which hotEncodeCharacter returns nil are skipped. The rows
# are collected first and the Matrix is built once, instead of rebuilding the
# whole Matrix for every character. Returns an empty (0 x 0) Matrix when no
# character could be encoded.
def hotEncodeSentance(sentance)
  encoded_rows = sentance.each_char
                         .map { |char| hotEncodeCharacter(char) }
                         .compact
                         .map(&:to_a)
  Matrix.rows(encoded_rows)
end
|
100
|
+
# Decodes a Matrix back into text by decoding each row with
# hotDecodeCharacter and concatenating the results in row order.
def hotDecodeSentance(matrix)
  (0...matrix.row_count()).reduce("") do |text, row_index|
    text + hotDecodeCharacter(matrix.row(row_index))
  end
end
|
109
|
+
# One-hot encodes a single character against @CharMatrix.
#
# char - the character to encode.
#
# Rebuilds @SelectionMatrix as a 1 x @Length row holding 1 at the character's
# position in @CharMatrix and 0 elsewhere, and returns that row as a Vector.
# Returns nil (leaving @SelectionMatrix all zeros) when the character is not
# in @CharMatrix; hotEncodeSentance checks for nil and skips such characters.
def hotEncodeCharacter(char)
  index = @CharMatrix.index(char)
  @SelectionMatrix = Matrix.build(1, @Length) { |_row, col| col == index ? 1 : 0 }
  return nil if index.nil?
  @SelectionMatrix.row(0)
end
|
116
|
+
# Decodes one encoded row back into a character.
#
# vector - anything that responds to to_a (a Vector or plain Array of scores).
#
# Picks the position of the largest value (the last such position when
# several are equal) and returns the @CharMatrix entry there. Returns '@'
# when the vector is empty, since there is no position to pick.
def hotDecodeCharacter(vector)
  scores = vector.to_a
  return '@' if scores.empty?
  _value, strongest = scores.each_with_index.max
  @CharMatrix[strongest]
end
|
126
|
+
# Fraction of character positions at which two strings differ.
#
# a, b - the strings to compare.
#
# Positions are compared pairwise from the start; positions present in only
# one string count as differences. The result is normalised by the length of
# the longer string, so it always lies between 0.0 (identical) and 1.0
# (nothing matches). Two empty strings give 0.0.
def stringDifferencePercent(a, b)
  longer = [a.size, b.size].max
  return 0.0 if longer.zero?
  same = a.each_char.zip(b.each_char).count { |left, right| left == right }
  (longer - same) / longer.to_f
end
|
132
|
+
end
|
data/lib/NETWORK.rb
ADDED
@@ -0,0 +1,163 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
=begin
|
3
|
+
██╗░░░░░░██████╗████████╗███╗░░░███╗
|
4
|
+
██║░░░░░██╔════╝╚══██╔══╝████╗░████║
|
5
|
+
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║
|
6
|
+
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║
|
7
|
+
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║
|
8
|
+
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝
|
9
|
+
Created: 14/09/2023.
|
10
|
+
Version: 1.0.0
|
11
|
+
Author: Ryan May
|
12
|
+
This is a text generator LSTM developed in Ruby.
|
13
|
+
In this project I gain more of an understanding of the training and creation of LSTM networks.
|
14
|
+
This LSTM network utilises the Numo::Narray library to calculate cell states and the cell network.
|
15
|
+
Expansion on this project will be utilised in my Honours project.
|
16
|
+
=end
|
17
|
+
require_relative 'ENCODER'
|
18
|
+
require_relative 'CELL'
|
19
|
+
require_relative 'DICTIONARY'
|
20
|
+
require 'matrix' # https://www.rubyguides.com/2019/01/ruby-matrix/
|
21
|
+
require 'numo/narray'
|
22
|
+
include Numo
|
23
|
+
=begin
|
24
|
+
██╗░░░░░░██████╗████████╗███╗░░░███╗ ███╗░░██╗███████╗████████╗░██╗░░░░░░░██╗░█████╗░██████╗░██╗░░██╗
|
25
|
+
██║░░░░░██╔════╝╚══██╔══╝████╗░████║ ████╗░██║██╔════╝╚══██╔══╝░██║░░██╗░░██║██╔══██╗██╔══██╗██║░██╔╝
|
26
|
+
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║ ██╔██╗██║█████╗░░░░░██║░░░░╚██╗████╗██╔╝██║░░██║██████╔╝█████═╝░
|
27
|
+
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║ ██║╚████║██╔══╝░░░░░██║░░░░░████╔═████║░██║░░██║██╔══██╗██╔═██╗░
|
28
|
+
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║ ██║░╚███║███████╗░░░██║░░░░░╚██╔╝░╚██╔╝░╚█████╔╝██║░░██║██║░╚██╗
|
29
|
+
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝ ╚═╝░░╚══╝╚══════╝░░░╚═╝░░░░░░╚═╝░░░╚═╝░░░╚════╝░╚═╝░░╚═╝╚═╝░░╚═╝
|
30
|
+
This class handles the single-height LSTM network. The LSTM network comprises an array of LSTM cell objects,
|
31
|
+
an input matrix, and a target matrix.
|
32
|
+
Forward propagation for the network, and backward network propagation is implemented here.
|
33
|
+
This is separate from the cellular-level forward and back propagation in the LSTM_CELL class.
|
34
|
+
=end
|
35
|
+
class LSTM_NETWORK
|
36
|
+
# Builds the network: a default dictionary and encoder, a 1 x nodes Matrix of
# initialised LSTM cells, and zeroed input/target arrays.
#
# nodes           - number of LSTM cells in the chain.
# x_dim           - width of every cell's vectors.
# alpha           - learning rate passed to every cell (also kept in @@Alpha).
# terminal_output - forwarded unchanged to each cell's init.
#
# Note that the default DICTIONARY is created but not initialised here; use
# setDictionary to supply a prepared one.
def init(nodes, x_dim, alpha, terminal_output=nil)
  @dict = DICTIONARY.new
  @encoder = ENCODER.new
  @encoder.init()
  @sz = x_dim
  @@Alpha = alpha
  @Nodes = nodes

  ### One initialised cell per column, created in column order
  @lstm_nodes = Matrix.build(1, @Nodes) do
    LSTM_CELL.new.tap { |cell| cell.init(@@Alpha, @sz, terminal_output) }
  end

  ## Input and target
  @Input = DFloat.zeros(nodes, @sz)
  @Target = DFloat.zeros(nodes, @sz)
end
|
53
|
+
# Returns the Matrix holding the network's LSTM cells.
def getLSTMNodes
  @lstm_nodes
end
|
56
|
+
# Replaces the dictionary used for word-level encoding and decoding.
# Returns the dictionary that was passed in.
def setDictionary(dictionary)
  replacement = dictionary
  @dict = replacement
end
|
59
|
+
# Sets the network target.
#
# target - already-encoded data (when mode is "encoded"), a String (encoded
#          per character by the encoder) or an Array of words (encoded by the
#          dictionary). The class checks are exact: subclasses of String or
#          Array are rejected.
# mode   - pass "encoded" to store target as it is.
#
# Raises RuntimeError for any other kind of target, leaving @Target unchanged.
# Returns the stored value.
def setTarget(target, mode=nil)
  @Target =
    if mode == "encoded"
      target
    elsif target.instance_of?(String)
      @encoder.matrixToNArray(@encoder.hotEncodeSentance(target))
    elsif target.instance_of?(Array)
      @dict.encodeArray(target)
    else
      raise "Target input is not of type Array or String"
    end
end
|
70
|
+
# Sets the network input.
#
# input - already-encoded data (when mode is "encoded"), a String (encoded
#         per character by the encoder) or an Array of words (encoded by the
#         dictionary). The class checks are exact: subclasses of String or
#         Array are rejected.
# mode  - pass "encoded" to store input as it is.
#
# Raises RuntimeError for any other kind of input, leaving @Input unchanged;
# the message names the input rather than the target. Returns the stored value.
def setInput(input, mode=nil)
  @Input =
    if mode == "encoded"
      input
    elsif input.instance_of?(String)
      @encoder.matrixToNArray(@encoder.hotEncodeSentance(input))
    elsif input.instance_of?(Array)
      @dict.encodeArray(input)
    else
      raise "Input is not of type Array or String"
    end
end
|
81
|
+
# Returns the network target in the requested representation.
#
# mode - "word_mode" (default): text decoded by the dictionary;
#        "char_mode": text decoded per character by the encoder;
#        "array_mode": the stored array itself.
#
# Returns nil for any other mode.
def getTarget(mode="word_mode")
  case mode
  when "word_mode" then @dict.decodeArray(@Target)
  when "char_mode" then @encoder.hotDecodeSentance(@encoder.nArrayToMatrix(@Target))
  when "array_mode" then @Target
  end
end
|
90
|
+
# Returns the network input in the requested representation.
#
# mode - "word_mode" (default): text decoded by the dictionary;
#        "char_mode": text decoded per character by the encoder;
#        "array_mode": the stored array itself.
#
# Returns nil for any other mode.
def getInput(mode="word_mode")
  case mode
  when "word_mode" then @dict.decodeArray(@Input)
  when "char_mode" then @encoder.hotDecodeSentance(@encoder.nArrayToMatrix(@Input))
  when "array_mode" then @Input
  end
end
|
99
|
+
# Returns the output of the last forwardPropagate in the requested
# representation.
#
# mode - "word_mode" (default): text decoded by the dictionary, taking the
#        strongest column of each row;
#        "char_mode": text decoded per character by the encoder;
#        "array_mode": the stored array itself.
#
# Returns nil for any other mode.
def getOutput(mode="word_mode")
  case mode
  when "word_mode" then @dict.decodeArrayByMaximum(@Output)
  when "char_mode" then @encoder.hotDecodeSentance(@encoder.nArrayToMatrix(@Output))
  when "array_mode" then @Output
  end
end
|
108
|
+
# Runs the forward pass through the whole chain of cells.
#
# initialH - previous-output row fed to the first cell (zeros by default).
# initialC - previous-state row fed to the first cell (zeros by default).
#            If either is nil, zeros are used for both.
#
# Row i of @Input and @Target is given to cell i; every later cell receives
# the output and state of the cell before it. Each cell's output is written
# to row i of @Output.
def forwardPropagate(initialH=DFloat.zeros(1, @sz), initialC=DFloat.zeros(1, @sz))
  @Output = DFloat.zeros(@Nodes, @sz)

  # First cell: seeded from the supplied (or zero) initial state.
  first_cell = @lstm_nodes[0,0]
  if initialH != nil && initialC != nil
    first_cell.setHprev(initialH)
    first_cell.setCprev(initialC)
  else
    first_cell.setHprev(DFloat.zeros(1, @sz))
    first_cell.setCprev(DFloat.zeros(1, @sz))
  end
  # Wrapping the row in an outer array yields a 1 x sz DFloat.
  first_cell.setXt(DFloat[@Input[0, true].to_a])
  first_cell.setYt(DFloat[@Target[0, true].to_a])
  first_cell.forwardPropagation()
  @Output[0, true] = first_cell.getHt()

  # Remaining cells: seeded from the cell before them.
  (1...@lstm_nodes.column_count()).each do |i|
    cell = @lstm_nodes[0,i]
    earlier = @lstm_nodes[0,i-1]
    input_row = DFloat[@Input[i, true].to_a]
    target_row = DFloat[@Target[i, true].to_a]
    cell.setHprev(earlier.getHt())
    cell.setCprev(earlier.getCt())
    cell.setXt(input_row)
    cell.setYt(target_row)
    cell.forwardPropagation()
    @Output[i, true] = cell.getHt()
  end
end
|
140
|
+
# Runs back propagation over every cell, accumulating weight changes inside
# the cells (apply them afterwards with applyWeightChange).
#
# For each cell the output difference is 2 * (target - output). Every cell
# after the first adds the bottom output difference of the cell before it and
# receives that cell's bottom state difference. Returns nil.
#
# NOTE(review): cells are visited from index 0 upwards, so differences flow
# from earlier cells to later ones; back propagation through time normally
# runs from the last step to the first - confirm this order is intended.
def backwardPropagate()
  first_cell = @lstm_nodes[0, 0]
  first_output_diff = 2 * (first_cell.getYt() - first_cell.getHt())
  first_cell.backwardPropagation(first_output_diff, DFloat.zeros(1, @sz))

  (1...@lstm_nodes.column_count()).each do |i|
    cell = @lstm_nodes[0, i]
    earlier = @lstm_nodes[0, i-1]
    output_diff = 2 * (cell.getYt() - cell.getHt())
    output_diff += earlier.getBottomDeltaHt()
    state_diff = earlier.getBottomDeltaCt()
    cell.backwardPropagation(output_diff, state_diff)
  end
  nil
end
|
158
|
+
# Tells every cell in the network, in column order, to apply its accumulated
# weight changes.
def applyWeightChange()
  (0...@lstm_nodes.column_count()).each do |column|
    @lstm_nodes[0,column].applyWeightChange()
  end
end
|
163
|
+
end
|
metadata
ADDED
@@ -0,0 +1,46 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: basic-lstm
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 1.0.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Ryan May
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2023-06-30 00:00:00.000000000 Z
|
12
|
+
dependencies: []
|
13
|
+
description:
|
14
|
+
email: 19477774@student.curtin.edu.au
|
15
|
+
executables: []
|
16
|
+
extensions: []
|
17
|
+
extra_rdoc_files: []
|
18
|
+
files:
|
19
|
+
- lib/CELL.rb
|
20
|
+
- lib/DICTIONARY.rb
|
21
|
+
- lib/ENCODER.rb
|
22
|
+
- lib/NETWORK.rb
|
23
|
+
homepage: https://github.com/ryan-n-may/Ruby_LSTM
|
24
|
+
licenses:
|
25
|
+
- MIT
|
26
|
+
metadata: {}
|
27
|
+
post_install_message:
|
28
|
+
rdoc_options: []
|
29
|
+
require_paths:
|
30
|
+
- lib
|
31
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
32
|
+
requirements:
|
33
|
+
- - ">="
|
34
|
+
- !ruby/object:Gem::Version
|
35
|
+
version: '0'
|
36
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
requirements: []
|
42
|
+
rubygems_version: 3.3.15
|
43
|
+
signing_key:
|
44
|
+
specification_version: 4
|
45
|
+
summary: A basic LSTM package
|
46
|
+
test_files: []
|