daimond 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CONTRIBUTIONG.md +160 -0
- data/README.ja.md +115 -0
- data/README.md +115 -0
- data/README.ru.md +116 -0
- data/ext/daimond_rust/Cargo.lock +353 -0
- data/ext/daimond_rust/Cargo.toml +13 -0
- data/ext/daimond_rust/build.rs +3 -0
- data/ext/daimond_rust/src/lib.rs +103 -0
- data/lib/daimond/autograd.rb +0 -0
- data/lib/daimond/data/data_loader.rb +41 -0
- data/lib/daimond/data/mnist.rb +56 -0
- data/lib/daimond/loss/cross_entropy.rb +45 -0
- data/lib/daimond/loss/mse.rb +0 -0
- data/lib/daimond/nn/conv2d.rb +117 -0
- data/lib/daimond/nn/conv2d_rust.rb +52 -0
- data/lib/daimond/nn/flatten.rb +29 -0
- data/lib/daimond/nn/functional.rb +0 -0
- data/lib/daimond/nn/linear.rb +22 -0
- data/lib/daimond/nn/max_pool2d.rb +69 -0
- data/lib/daimond/nn/max_pool2d_rust.rb +33 -0
- data/lib/daimond/nn/module.rb +60 -0
- data/lib/daimond/optim/adam.rb +41 -0
- data/lib/daimond/optim/sgd.rb +25 -0
- data/lib/daimond/rust/daimond_rust.bundle +0 -0
- data/lib/daimond/rust_backend.rb +23 -0
- data/lib/daimond/rust_bridge.rb +63 -0
- data/lib/daimond/tensor.rb +241 -0
- data/lib/daimond/utils/training_logger.rb +111 -0
- data/lib/daimond/version.rb +3 -0
- data/lib/daimond.rb +40 -0
- metadata +134 -0
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
# This file is automatically @generated by Cargo.
|
|
2
|
+
# It is not intended for manual editing.
|
|
3
|
+
version = 4
|
|
4
|
+
|
|
5
|
+
[[package]]
|
|
6
|
+
name = "aho-corasick"
|
|
7
|
+
version = "1.1.4"
|
|
8
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
9
|
+
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"memchr",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[[package]]
|
|
15
|
+
name = "autocfg"
|
|
16
|
+
version = "1.5.0"
|
|
17
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
18
|
+
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
|
19
|
+
|
|
20
|
+
[[package]]
|
|
21
|
+
name = "bindgen"
|
|
22
|
+
version = "0.69.5"
|
|
23
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
24
|
+
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
|
25
|
+
dependencies = [
|
|
26
|
+
"bitflags",
|
|
27
|
+
"cexpr",
|
|
28
|
+
"clang-sys",
|
|
29
|
+
"itertools",
|
|
30
|
+
"lazy_static",
|
|
31
|
+
"lazycell",
|
|
32
|
+
"proc-macro2",
|
|
33
|
+
"quote",
|
|
34
|
+
"regex",
|
|
35
|
+
"rustc-hash",
|
|
36
|
+
"shlex",
|
|
37
|
+
"syn",
|
|
38
|
+
]
|
|
39
|
+
|
|
40
|
+
[[package]]
|
|
41
|
+
name = "bitflags"
|
|
42
|
+
version = "2.10.0"
|
|
43
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
44
|
+
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
|
|
45
|
+
|
|
46
|
+
[[package]]
|
|
47
|
+
name = "cexpr"
|
|
48
|
+
version = "0.6.0"
|
|
49
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
50
|
+
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
|
51
|
+
dependencies = [
|
|
52
|
+
"nom",
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
[[package]]
|
|
56
|
+
name = "cfg-if"
|
|
57
|
+
version = "1.0.4"
|
|
58
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
59
|
+
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
|
60
|
+
|
|
61
|
+
[[package]]
|
|
62
|
+
name = "clang-sys"
|
|
63
|
+
version = "1.8.1"
|
|
64
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
65
|
+
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
|
66
|
+
dependencies = [
|
|
67
|
+
"glob",
|
|
68
|
+
"libc",
|
|
69
|
+
"libloading",
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
[[package]]
|
|
73
|
+
name = "daimond_rust"
|
|
74
|
+
version = "0.1.0"
|
|
75
|
+
dependencies = [
|
|
76
|
+
"magnus",
|
|
77
|
+
"ndarray",
|
|
78
|
+
"rb-sys",
|
|
79
|
+
]
|
|
80
|
+
|
|
81
|
+
[[package]]
|
|
82
|
+
name = "either"
|
|
83
|
+
version = "1.15.0"
|
|
84
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
85
|
+
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
|
86
|
+
|
|
87
|
+
[[package]]
|
|
88
|
+
name = "glob"
|
|
89
|
+
version = "0.3.3"
|
|
90
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
91
|
+
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
|
92
|
+
|
|
93
|
+
[[package]]
|
|
94
|
+
name = "itertools"
|
|
95
|
+
version = "0.12.1"
|
|
96
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
97
|
+
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
|
98
|
+
dependencies = [
|
|
99
|
+
"either",
|
|
100
|
+
]
|
|
101
|
+
|
|
102
|
+
[[package]]
|
|
103
|
+
name = "lazy_static"
|
|
104
|
+
version = "1.5.0"
|
|
105
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
106
|
+
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
|
107
|
+
|
|
108
|
+
[[package]]
|
|
109
|
+
name = "lazycell"
|
|
110
|
+
version = "1.3.0"
|
|
111
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
112
|
+
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
|
113
|
+
|
|
114
|
+
[[package]]
|
|
115
|
+
name = "libc"
|
|
116
|
+
version = "0.2.180"
|
|
117
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
118
|
+
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
|
|
119
|
+
|
|
120
|
+
[[package]]
|
|
121
|
+
name = "libloading"
|
|
122
|
+
version = "0.8.9"
|
|
123
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
124
|
+
checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
|
|
125
|
+
dependencies = [
|
|
126
|
+
"cfg-if",
|
|
127
|
+
"windows-link",
|
|
128
|
+
]
|
|
129
|
+
|
|
130
|
+
[[package]]
|
|
131
|
+
name = "magnus"
|
|
132
|
+
version = "0.6.4"
|
|
133
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
134
|
+
checksum = "b1597ef40aa8c36be098249e82c9a20cf7199278ac1c1a1a995eeead6a184479"
|
|
135
|
+
dependencies = [
|
|
136
|
+
"magnus-macros",
|
|
137
|
+
"rb-sys",
|
|
138
|
+
"rb-sys-env",
|
|
139
|
+
"seq-macro",
|
|
140
|
+
]
|
|
141
|
+
|
|
142
|
+
[[package]]
|
|
143
|
+
name = "magnus-macros"
|
|
144
|
+
version = "0.6.0"
|
|
145
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
146
|
+
checksum = "5968c820e2960565f647819f5928a42d6e874551cab9d88d75e3e0660d7f71e3"
|
|
147
|
+
dependencies = [
|
|
148
|
+
"proc-macro2",
|
|
149
|
+
"quote",
|
|
150
|
+
"syn",
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
[[package]]
|
|
154
|
+
name = "matrixmultiply"
|
|
155
|
+
version = "0.3.10"
|
|
156
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
157
|
+
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
|
|
158
|
+
dependencies = [
|
|
159
|
+
"autocfg",
|
|
160
|
+
"rawpointer",
|
|
161
|
+
]
|
|
162
|
+
|
|
163
|
+
[[package]]
|
|
164
|
+
name = "memchr"
|
|
165
|
+
version = "2.7.6"
|
|
166
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
167
|
+
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
|
168
|
+
|
|
169
|
+
[[package]]
|
|
170
|
+
name = "minimal-lexical"
|
|
171
|
+
version = "0.2.1"
|
|
172
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
173
|
+
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
|
174
|
+
|
|
175
|
+
[[package]]
|
|
176
|
+
name = "ndarray"
|
|
177
|
+
version = "0.15.6"
|
|
178
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
179
|
+
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
|
180
|
+
dependencies = [
|
|
181
|
+
"matrixmultiply",
|
|
182
|
+
"num-complex",
|
|
183
|
+
"num-integer",
|
|
184
|
+
"num-traits",
|
|
185
|
+
"rawpointer",
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
[[package]]
|
|
189
|
+
name = "nom"
|
|
190
|
+
version = "7.1.3"
|
|
191
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
192
|
+
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
|
193
|
+
dependencies = [
|
|
194
|
+
"memchr",
|
|
195
|
+
"minimal-lexical",
|
|
196
|
+
]
|
|
197
|
+
|
|
198
|
+
[[package]]
|
|
199
|
+
name = "num-complex"
|
|
200
|
+
version = "0.4.6"
|
|
201
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
202
|
+
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
|
203
|
+
dependencies = [
|
|
204
|
+
"num-traits",
|
|
205
|
+
]
|
|
206
|
+
|
|
207
|
+
[[package]]
|
|
208
|
+
name = "num-integer"
|
|
209
|
+
version = "0.1.46"
|
|
210
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
211
|
+
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
|
212
|
+
dependencies = [
|
|
213
|
+
"num-traits",
|
|
214
|
+
]
|
|
215
|
+
|
|
216
|
+
[[package]]
|
|
217
|
+
name = "num-traits"
|
|
218
|
+
version = "0.2.19"
|
|
219
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
220
|
+
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
|
221
|
+
dependencies = [
|
|
222
|
+
"autocfg",
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
[[package]]
|
|
226
|
+
name = "proc-macro2"
|
|
227
|
+
version = "1.0.106"
|
|
228
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
229
|
+
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
|
230
|
+
dependencies = [
|
|
231
|
+
"unicode-ident",
|
|
232
|
+
]
|
|
233
|
+
|
|
234
|
+
[[package]]
|
|
235
|
+
name = "quote"
|
|
236
|
+
version = "1.0.44"
|
|
237
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
238
|
+
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
|
|
239
|
+
dependencies = [
|
|
240
|
+
"proc-macro2",
|
|
241
|
+
]
|
|
242
|
+
|
|
243
|
+
[[package]]
|
|
244
|
+
name = "rawpointer"
|
|
245
|
+
version = "0.2.1"
|
|
246
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
247
|
+
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
|
248
|
+
|
|
249
|
+
[[package]]
|
|
250
|
+
name = "rb-sys"
|
|
251
|
+
version = "0.9.124"
|
|
252
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
253
|
+
checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489"
|
|
254
|
+
dependencies = [
|
|
255
|
+
"rb-sys-build",
|
|
256
|
+
]
|
|
257
|
+
|
|
258
|
+
[[package]]
|
|
259
|
+
name = "rb-sys-build"
|
|
260
|
+
version = "0.9.124"
|
|
261
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
262
|
+
checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3"
|
|
263
|
+
dependencies = [
|
|
264
|
+
"bindgen",
|
|
265
|
+
"lazy_static",
|
|
266
|
+
"proc-macro2",
|
|
267
|
+
"quote",
|
|
268
|
+
"regex",
|
|
269
|
+
"shell-words",
|
|
270
|
+
"syn",
|
|
271
|
+
]
|
|
272
|
+
|
|
273
|
+
[[package]]
|
|
274
|
+
name = "rb-sys-env"
|
|
275
|
+
version = "0.1.2"
|
|
276
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
277
|
+
checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb"
|
|
278
|
+
|
|
279
|
+
[[package]]
|
|
280
|
+
name = "regex"
|
|
281
|
+
version = "1.12.2"
|
|
282
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
283
|
+
checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
|
|
284
|
+
dependencies = [
|
|
285
|
+
"aho-corasick",
|
|
286
|
+
"memchr",
|
|
287
|
+
"regex-automata",
|
|
288
|
+
"regex-syntax",
|
|
289
|
+
]
|
|
290
|
+
|
|
291
|
+
[[package]]
|
|
292
|
+
name = "regex-automata"
|
|
293
|
+
version = "0.4.13"
|
|
294
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
295
|
+
checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
|
|
296
|
+
dependencies = [
|
|
297
|
+
"aho-corasick",
|
|
298
|
+
"memchr",
|
|
299
|
+
"regex-syntax",
|
|
300
|
+
]
|
|
301
|
+
|
|
302
|
+
[[package]]
|
|
303
|
+
name = "regex-syntax"
|
|
304
|
+
version = "0.8.8"
|
|
305
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
306
|
+
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
|
|
307
|
+
|
|
308
|
+
[[package]]
|
|
309
|
+
name = "rustc-hash"
|
|
310
|
+
version = "1.1.0"
|
|
311
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
312
|
+
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
|
313
|
+
|
|
314
|
+
[[package]]
|
|
315
|
+
name = "seq-macro"
|
|
316
|
+
version = "0.3.6"
|
|
317
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
318
|
+
checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc"
|
|
319
|
+
|
|
320
|
+
[[package]]
|
|
321
|
+
name = "shell-words"
|
|
322
|
+
version = "1.1.1"
|
|
323
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
324
|
+
checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
|
|
325
|
+
|
|
326
|
+
[[package]]
|
|
327
|
+
name = "shlex"
|
|
328
|
+
version = "1.3.0"
|
|
329
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
330
|
+
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
|
331
|
+
|
|
332
|
+
[[package]]
|
|
333
|
+
name = "syn"
|
|
334
|
+
version = "2.0.114"
|
|
335
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
336
|
+
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
|
|
337
|
+
dependencies = [
|
|
338
|
+
"proc-macro2",
|
|
339
|
+
"quote",
|
|
340
|
+
"unicode-ident",
|
|
341
|
+
]
|
|
342
|
+
|
|
343
|
+
[[package]]
|
|
344
|
+
name = "unicode-ident"
|
|
345
|
+
version = "1.0.22"
|
|
346
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
347
|
+
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
|
348
|
+
|
|
349
|
+
[[package]]
|
|
350
|
+
name = "windows-link"
|
|
351
|
+
version = "0.2.1"
|
|
352
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
353
|
+
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
use magnus::{prelude::*, Ruby, RArray, Value};
|
|
2
|
+
use ndarray::Array2;
|
|
3
|
+
use std::f64;
|
|
4
|
+
|
|
5
|
+
// === Flat MatMul ===
|
|
6
|
+
fn fast_matmul_flat(
|
|
7
|
+
a: Vec<f64>,
|
|
8
|
+
b: Vec<f64>,
|
|
9
|
+
m: usize,
|
|
10
|
+
k: usize,
|
|
11
|
+
n: usize
|
|
12
|
+
) -> Vec<f64> {
|
|
13
|
+
let a_arr = Array2::from_shape_vec((m, k), a).unwrap();
|
|
14
|
+
let b_arr = Array2::from_shape_vec((k, n), b).unwrap();
|
|
15
|
+
let c_arr = a_arr.dot(&b_arr);
|
|
16
|
+
c_arr.into_raw_vec()
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
// === Conv2D ===
|
|
20
|
+
fn conv2d_forward(
|
|
21
|
+
input: Vec<f64>,
|
|
22
|
+
weight: Vec<f64>,
|
|
23
|
+
bias: Vec<f64>,
|
|
24
|
+
batch: usize,
|
|
25
|
+
in_c: usize,
|
|
26
|
+
out_c: usize,
|
|
27
|
+
h: usize,
|
|
28
|
+
w: usize,
|
|
29
|
+
k: usize,
|
|
30
|
+
) -> Vec<f64> {
|
|
31
|
+
let h_out = h - k + 1;
|
|
32
|
+
let w_out = w - k + 1;
|
|
33
|
+
let mut output = vec![0.0; batch * out_c * h_out * w_out];
|
|
34
|
+
|
|
35
|
+
for b in 0..batch {
|
|
36
|
+
for oc in 0..out_c {
|
|
37
|
+
for i in 0..h_out {
|
|
38
|
+
for j in 0..w_out {
|
|
39
|
+
let mut sum = bias[oc];
|
|
40
|
+
for ic in 0..in_c {
|
|
41
|
+
for ki in 0..k {
|
|
42
|
+
for kj in 0..k {
|
|
43
|
+
let in_idx = ((b * in_c + ic) * h + i + ki) * w + j + kj;
|
|
44
|
+
let w_idx = ((oc * in_c + ic) * k + ki) * k + kj;
|
|
45
|
+
sum += input[in_idx] * weight[w_idx];
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
let out_idx = ((b * out_c + oc) * h_out + i) * w_out + j;
|
|
50
|
+
output[out_idx] = sum;
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
output
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
// === MaxPool2D ===
|
|
59
|
+
/// Non-overlapping 2-D max pooling (window `k × k`, stride `k`) on a flat
/// row-major `[batch, channels, h, w]` buffer.
///
/// Returns a flat `[batch, channels, h / k, w / k]` buffer; trailing rows and
/// columns that do not fill a whole window are dropped (floor division).
/// Windows are seeded with `-inf`, so a window of all `-inf` yields `-inf`
/// rather than `f64::MIN`. `f64::max` ignores NaN operands.
///
/// # Panics
/// Panics with a descriptive message if `k` is 0 (previously a division by
/// zero) or if `input.len()` disagrees with `batch * channels * h * w`.
fn maxpool2d_forward(
    input: Vec<f64>,
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    k: usize,
) -> Vec<f64> {
    assert!(k > 0, "maxpool2d_forward: kernel size must be positive");
    assert_eq!(
        input.len(),
        batch * channels * h * w,
        "maxpool2d_forward: input length must equal batch*channels*h*w"
    );

    let h_out = h / k;
    let w_out = w / k;
    let mut output = vec![0.0; batch * channels * h_out * w_out];

    for b in 0..batch {
        for c in 0..channels {
            // Row offset of the (b, c) input plane.
            let plane = (b * channels + c) * h;
            for i in 0..h_out {
                for j in 0..w_out {
                    let mut max_val = f64::NEG_INFINITY;
                    for ki in 0..k {
                        for kj in 0..k {
                            let in_idx = (plane + i * k + ki) * w + j * k + kj;
                            max_val = max_val.max(input[in_idx]);
                        }
                    }
                    output[((b * channels + c) * h_out + i) * w_out + j] = max_val;
                }
            }
        }
    }
    output
}
|
|
90
|
+
|
|
91
|
+
// === Initialization ===
|
|
92
|
+
#[magnus::init]
|
|
93
|
+
fn init(ruby: &Ruby) -> Result<(), magnus::Error> {
|
|
94
|
+
let module = ruby.define_module("Daimond")?
|
|
95
|
+
.define_module("Rust")?;
|
|
96
|
+
|
|
97
|
+
// Только flat-версии (быстрее и проще)
|
|
98
|
+
module.define_singleton_method("fast_matmul_flat", magnus::function!(fast_matmul_flat, 5))?;
|
|
99
|
+
module.define_singleton_method("conv2d_native", magnus::function!(conv2d_forward, 9))?;
|
|
100
|
+
module.define_singleton_method("maxpool2d_native", magnus::function!(maxpool2d_forward, 6))?;
|
|
101
|
+
|
|
102
|
+
Ok(())
|
|
103
|
+
}
|
|
File without changes
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
module Daimond
  module Data
    # Iterates over an in-memory dataset in mini-batches, optionally
    # reshuffling the sample order at the start of every pass.
    class DataLoader
      # @param images [Array<Array<Numeric>>] one feature row per sample
      # @param labels [Array<Integer>] one class label per sample, aligned with +images+
      # @param batch_size [Integer] samples per batch; the final batch may be smaller
      # @param shuffle [Boolean] whether #reset shuffles the sample order
      # @raise [ArgumentError] if +batch_size+ is not a positive integer
      def initialize(images, labels, batch_size: 32, shuffle: true)
        @batch_size = Integer(batch_size)
        # A zero batch size never advances @position (infinite loop in
        # #each_batch); a negative one makes Array#[] return nil.
        raise ArgumentError, "batch_size must be positive, got #{batch_size.inspect}" unless @batch_size.positive?

        @images = images
        @labels = labels
        @shuffle = shuffle
        @n_samples = images.length
        reset
      end

      # Rewinds to the first batch, reshuffling the sample order if enabled.
      def reset
        @indices = (0...@n_samples).to_a
        @indices.shuffle! if @shuffle
        @position = 0
      end

      # Yields every batch of one full pass over the data; each pass starts
      # with #reset.
      #
      # @yieldparam x [Tensor] features, built via Numo::DFloat from the batch rows
      #   (e.g. [batch_size, 784] for flattened MNIST images)
      # @yieldparam y [Tensor] labels, built via Numo::Int32, shape [batch_size]
      # @return [Enumerator] when called without a block, otherwise nil
      def each_batch
        return enum_for(:each_batch) { batches_count } unless block_given?

        reset
        while @position < @n_samples
          batch_indices = @indices[@position, @batch_size]
          @position += @batch_size

          batch_images = batch_indices.map { |i| @images[i] }
          batch_labels = batch_indices.map { |i| @labels[i] }

          # Convert the batch into Tensors.
          x = Tensor.new(Numo::DFloat[*batch_images])
          y = Tensor.new(Numo::Int32[*batch_labels])

          yield x, y
        end
      end

      # @return [Integer] number of batches per pass, counting a final partial batch
      def batches_count
        # Integer ceiling division.
        (@n_samples + @batch_size - 1) / @batch_size
      end
    end
  end
end
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
require 'open-uri'
|
|
2
|
+
require 'zlib'
|
|
3
|
+
require 'fileutils'
|
|
4
|
+
|
|
5
|
+
module Daimond
  module Data
    # Downloads and parses the gzip-compressed MNIST IDX files into plain Ruby arrays.
    class MNIST
      URL_BASE = 'https://ossci-datasets.s3.amazonaws.com/mnist/'
      FILES = {
        train_images: 'train-images-idx3-ubyte.gz',
        train_labels: 'train-labels-idx1-ubyte.gz',
        test_images: 't10k-images-idx3-ubyte.gz',
        test_labels: 't10k-labels-idx1-ubyte.gz'
      }.freeze

      # Default cache directory, relative to the current working directory.
      DEFAULT_PATH = 'data/mnist'

      # IDX magic numbers: 0x00000803 (unsigned bytes, 3 dims) for image files
      # and 0x00000801 (unsigned bytes, 1 dim) for label files.
      IMAGES_MAGIC = 2051
      LABELS_MAGIC = 2049

      # Raised when a file is not a well-formed MNIST IDX file.
      class FormatError < StandardError; end

      # Ensures +file+ exists under +path+, downloading it from URL_BASE if
      # missing, and returns its local path.
      #
      # The body is streamed into "<file>.part" and renamed into place only once
      # complete, so an interrupted transfer never leaves a truncated file that
      # later calls would mistake for a finished download.
      def self.download(file, path: DEFAULT_PATH)
        ::FileUtils.mkdir_p(path) # also creates missing parent directories
        filepath = File.join(path, file)
        return filepath if File.exist?(filepath)

        puts "Downloading #{file}..."
        partial = "#{filepath}.part"
        begin
          URI.open("#{URL_BASE}#{file}") do |remote|
            File.open(partial, 'wb') { |out| IO.copy_stream(remote, out) }
          end
          File.rename(partial, filepath)
        ensure
          ::FileUtils.rm_f(partial)
        end
        filepath
      end

      # Loads an IDX image file.
      #
      # @param file [String] file name under +path+ (see FILES)
      # @param path [String] cache directory
      # @return [Array<Array<Float>>] one row of n_rows * n_cols pixels per image, scaled to [0, 1]
      # @raise [FormatError] on a wrong magic number or truncated data
      def self.load_images(file, path: DEFAULT_PATH)
        filepath = download(file, path: path)
        Zlib::GzipReader.open(filepath) do |f|
          check_magic!(read_u32(f, filepath), IMAGES_MAGIC, filepath)
          n_images = read_u32(f, filepath)
          n_rows = read_u32(f, filepath)
          n_cols = read_u32(f, filepath)
          pixels_per_image = n_rows * n_cols

          pixels = read_exactly(f, n_images * pixels_per_image, filepath).unpack('C*')
          # Normalize to [0, 1] and reshape to [n_images, n_rows * n_cols].
          pixels.each_slice(pixels_per_image).map do |img|
            img.map { |pixel| pixel / 255.0 }
          end
        end
      end

      # Loads an IDX label file.
      #
      # @param file [String] file name under +path+ (see FILES)
      # @param path [String] cache directory
      # @return [Array<Integer>] one byte-valued label per sample
      # @raise [FormatError] on a wrong magic number or truncated data
      def self.load_labels(file, path: DEFAULT_PATH)
        filepath = download(file, path: path)
        Zlib::GzipReader.open(filepath) do |f|
          check_magic!(read_u32(f, filepath), LABELS_MAGIC, filepath)
          n_labels = read_u32(f, filepath)
          read_exactly(f, n_labels, filepath).unpack('C*')
        end
      end

      # Reads one big-endian unsigned 32-bit IDX header field.
      def self.read_u32(io, filepath)
        read_exactly(io, 4, filepath).unpack1('N')
      end

      # Reads exactly +count+ bytes or raises FormatError if the stream ends early.
      def self.read_exactly(io, count, filepath)
        return ''.b if count.zero?

        bytes = io.read(count)
        return bytes if bytes && bytes.bytesize == count

        got = bytes ? bytes.bytesize : 0
        raise FormatError, "#{filepath}: expected #{count} bytes, got #{got} (truncated file?)"
      end

      # Raises FormatError unless the header's magic number is the expected one.
      def self.check_magic!(actual, expected, filepath)
        return if actual == expected

        raise FormatError, "#{filepath}: bad IDX magic number #{actual} (expected #{expected})"
      end

      private_class_method :read_u32, :read_exactly, :check_magic!
    end
  end
end
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
require_relative '../tensor'
|
|
2
|
+
require 'numo/narray'
|
|
3
|
+
|
|
4
|
+
module Daimond
  module Loss
    # Cross-entropy loss over class probabilities.
    #
    # +pred+ is expected to already contain softmax probabilities. The backward
    # pass uses the fused softmax + cross-entropy gradient
    # (pred - one_hot(target)) / batch_size.
    class CrossEntropyLoss
      # Lower bound applied to probabilities before taking the log, so a
      # probability that underflowed to 0 yields a large finite loss instead of
      # Infinity.
      EPSILON = 1e-12

      def initialize
      end

      # @param pred [Tensor] [batch_size, num_classes] softmax probabilities
      # @param target [Tensor] [batch_size] integer class indices
      # @return [Tensor] single-element Tensor holding the mean negative log-likelihood
      def forward(pred, target)
        batch_size = pred.shape[0]

        # Loss value for monitoring only; the backward pass does not use it.
        log_probs = Numo::NMath.log(pred.data.clip(EPSILON, 1.0))
        correct_log_probs = Numo::DFloat.zeros(batch_size)

        batch_size.times do |i|
          correct_log_probs[i] = log_probs[i, target.data[i]]
        end

        loss_value = -correct_log_probs.mean

        out = Tensor.new(Numo::DFloat[loss_value], prev: [pred], op: 'cross_entropy')

        # Backward: gradient of cross_entropy + softmax = pred - one_hot(target).
        # NOTE(review): out.grad is not read here, so this is only correct when the
        # loss is the root of backprop (upstream gradient 1). It also presumes the
        # softmax that produced +pred+ does not apply its Jacobian again — confirm
        # against the softmax op in Tensor.
        out._backward = lambda do
          grad_input = pred.data.dup # softmax output (unclamped)
          batch_size.times do |i|
            grad_input[i, target.data[i]] -= 1.0
          end
          grad_input /= batch_size
          pred.grad += grad_input
        end

        out
      end

      # Alias for #forward so the loss can be called like a function.
      def call(pred, target)
        forward(pred, target)
      end
    end
  end
end
|
|
File without changes
|