daimond 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,353 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "aho-corasick"
7
+ version = "1.1.4"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
10
+ dependencies = [
11
+ "memchr",
12
+ ]
13
+
14
+ [[package]]
15
+ name = "autocfg"
16
+ version = "1.5.0"
17
+ source = "registry+https://github.com/rust-lang/crates.io-index"
18
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
19
+
20
+ [[package]]
21
+ name = "bindgen"
22
+ version = "0.69.5"
23
+ source = "registry+https://github.com/rust-lang/crates.io-index"
24
+ checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
25
+ dependencies = [
26
+ "bitflags",
27
+ "cexpr",
28
+ "clang-sys",
29
+ "itertools",
30
+ "lazy_static",
31
+ "lazycell",
32
+ "proc-macro2",
33
+ "quote",
34
+ "regex",
35
+ "rustc-hash",
36
+ "shlex",
37
+ "syn",
38
+ ]
39
+
40
+ [[package]]
41
+ name = "bitflags"
42
+ version = "2.10.0"
43
+ source = "registry+https://github.com/rust-lang/crates.io-index"
44
+ checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
45
+
46
+ [[package]]
47
+ name = "cexpr"
48
+ version = "0.6.0"
49
+ source = "registry+https://github.com/rust-lang/crates.io-index"
50
+ checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
51
+ dependencies = [
52
+ "nom",
53
+ ]
54
+
55
+ [[package]]
56
+ name = "cfg-if"
57
+ version = "1.0.4"
58
+ source = "registry+https://github.com/rust-lang/crates.io-index"
59
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
60
+
61
+ [[package]]
62
+ name = "clang-sys"
63
+ version = "1.8.1"
64
+ source = "registry+https://github.com/rust-lang/crates.io-index"
65
+ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
66
+ dependencies = [
67
+ "glob",
68
+ "libc",
69
+ "libloading",
70
+ ]
71
+
72
+ [[package]]
73
+ name = "daimond_rust"
74
+ version = "0.1.0"
75
+ dependencies = [
76
+ "magnus",
77
+ "ndarray",
78
+ "rb-sys",
79
+ ]
80
+
81
+ [[package]]
82
+ name = "either"
83
+ version = "1.15.0"
84
+ source = "registry+https://github.com/rust-lang/crates.io-index"
85
+ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
86
+
87
+ [[package]]
88
+ name = "glob"
89
+ version = "0.3.3"
90
+ source = "registry+https://github.com/rust-lang/crates.io-index"
91
+ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
92
+
93
+ [[package]]
94
+ name = "itertools"
95
+ version = "0.12.1"
96
+ source = "registry+https://github.com/rust-lang/crates.io-index"
97
+ checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
98
+ dependencies = [
99
+ "either",
100
+ ]
101
+
102
+ [[package]]
103
+ name = "lazy_static"
104
+ version = "1.5.0"
105
+ source = "registry+https://github.com/rust-lang/crates.io-index"
106
+ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
107
+
108
+ [[package]]
109
+ name = "lazycell"
110
+ version = "1.3.0"
111
+ source = "registry+https://github.com/rust-lang/crates.io-index"
112
+ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
113
+
114
+ [[package]]
115
+ name = "libc"
116
+ version = "0.2.180"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
119
+
120
+ [[package]]
121
+ name = "libloading"
122
+ version = "0.8.9"
123
+ source = "registry+https://github.com/rust-lang/crates.io-index"
124
+ checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
125
+ dependencies = [
126
+ "cfg-if",
127
+ "windows-link",
128
+ ]
129
+
130
+ [[package]]
131
+ name = "magnus"
132
+ version = "0.6.4"
133
+ source = "registry+https://github.com/rust-lang/crates.io-index"
134
+ checksum = "b1597ef40aa8c36be098249e82c9a20cf7199278ac1c1a1a995eeead6a184479"
135
+ dependencies = [
136
+ "magnus-macros",
137
+ "rb-sys",
138
+ "rb-sys-env",
139
+ "seq-macro",
140
+ ]
141
+
142
+ [[package]]
143
+ name = "magnus-macros"
144
+ version = "0.6.0"
145
+ source = "registry+https://github.com/rust-lang/crates.io-index"
146
+ checksum = "5968c820e2960565f647819f5928a42d6e874551cab9d88d75e3e0660d7f71e3"
147
+ dependencies = [
148
+ "proc-macro2",
149
+ "quote",
150
+ "syn",
151
+ ]
152
+
153
+ [[package]]
154
+ name = "matrixmultiply"
155
+ version = "0.3.10"
156
+ source = "registry+https://github.com/rust-lang/crates.io-index"
157
+ checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
158
+ dependencies = [
159
+ "autocfg",
160
+ "rawpointer",
161
+ ]
162
+
163
+ [[package]]
164
+ name = "memchr"
165
+ version = "2.7.6"
166
+ source = "registry+https://github.com/rust-lang/crates.io-index"
167
+ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
168
+
169
+ [[package]]
170
+ name = "minimal-lexical"
171
+ version = "0.2.1"
172
+ source = "registry+https://github.com/rust-lang/crates.io-index"
173
+ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
174
+
175
+ [[package]]
176
+ name = "ndarray"
177
+ version = "0.15.6"
178
+ source = "registry+https://github.com/rust-lang/crates.io-index"
179
+ checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
180
+ dependencies = [
181
+ "matrixmultiply",
182
+ "num-complex",
183
+ "num-integer",
184
+ "num-traits",
185
+ "rawpointer",
186
+ ]
187
+
188
+ [[package]]
189
+ name = "nom"
190
+ version = "7.1.3"
191
+ source = "registry+https://github.com/rust-lang/crates.io-index"
192
+ checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
193
+ dependencies = [
194
+ "memchr",
195
+ "minimal-lexical",
196
+ ]
197
+
198
+ [[package]]
199
+ name = "num-complex"
200
+ version = "0.4.6"
201
+ source = "registry+https://github.com/rust-lang/crates.io-index"
202
+ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
203
+ dependencies = [
204
+ "num-traits",
205
+ ]
206
+
207
+ [[package]]
208
+ name = "num-integer"
209
+ version = "0.1.46"
210
+ source = "registry+https://github.com/rust-lang/crates.io-index"
211
+ checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
212
+ dependencies = [
213
+ "num-traits",
214
+ ]
215
+
216
+ [[package]]
217
+ name = "num-traits"
218
+ version = "0.2.19"
219
+ source = "registry+https://github.com/rust-lang/crates.io-index"
220
+ checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
221
+ dependencies = [
222
+ "autocfg",
223
+ ]
224
+
225
+ [[package]]
226
+ name = "proc-macro2"
227
+ version = "1.0.106"
228
+ source = "registry+https://github.com/rust-lang/crates.io-index"
229
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
230
+ dependencies = [
231
+ "unicode-ident",
232
+ ]
233
+
234
+ [[package]]
235
+ name = "quote"
236
+ version = "1.0.44"
237
+ source = "registry+https://github.com/rust-lang/crates.io-index"
238
+ checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
239
+ dependencies = [
240
+ "proc-macro2",
241
+ ]
242
+
243
+ [[package]]
244
+ name = "rawpointer"
245
+ version = "0.2.1"
246
+ source = "registry+https://github.com/rust-lang/crates.io-index"
247
+ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
248
+
249
+ [[package]]
250
+ name = "rb-sys"
251
+ version = "0.9.124"
252
+ source = "registry+https://github.com/rust-lang/crates.io-index"
253
+ checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489"
254
+ dependencies = [
255
+ "rb-sys-build",
256
+ ]
257
+
258
+ [[package]]
259
+ name = "rb-sys-build"
260
+ version = "0.9.124"
261
+ source = "registry+https://github.com/rust-lang/crates.io-index"
262
+ checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3"
263
+ dependencies = [
264
+ "bindgen",
265
+ "lazy_static",
266
+ "proc-macro2",
267
+ "quote",
268
+ "regex",
269
+ "shell-words",
270
+ "syn",
271
+ ]
272
+
273
+ [[package]]
274
+ name = "rb-sys-env"
275
+ version = "0.1.2"
276
+ source = "registry+https://github.com/rust-lang/crates.io-index"
277
+ checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb"
278
+
279
+ [[package]]
280
+ name = "regex"
281
+ version = "1.12.2"
282
+ source = "registry+https://github.com/rust-lang/crates.io-index"
283
+ checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
284
+ dependencies = [
285
+ "aho-corasick",
286
+ "memchr",
287
+ "regex-automata",
288
+ "regex-syntax",
289
+ ]
290
+
291
+ [[package]]
292
+ name = "regex-automata"
293
+ version = "0.4.13"
294
+ source = "registry+https://github.com/rust-lang/crates.io-index"
295
+ checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
296
+ dependencies = [
297
+ "aho-corasick",
298
+ "memchr",
299
+ "regex-syntax",
300
+ ]
301
+
302
+ [[package]]
303
+ name = "regex-syntax"
304
+ version = "0.8.8"
305
+ source = "registry+https://github.com/rust-lang/crates.io-index"
306
+ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
307
+
308
+ [[package]]
309
+ name = "rustc-hash"
310
+ version = "1.1.0"
311
+ source = "registry+https://github.com/rust-lang/crates.io-index"
312
+ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
313
+
314
+ [[package]]
315
+ name = "seq-macro"
316
+ version = "0.3.6"
317
+ source = "registry+https://github.com/rust-lang/crates.io-index"
318
+ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc"
319
+
320
+ [[package]]
321
+ name = "shell-words"
322
+ version = "1.1.1"
323
+ source = "registry+https://github.com/rust-lang/crates.io-index"
324
+ checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
325
+
326
+ [[package]]
327
+ name = "shlex"
328
+ version = "1.3.0"
329
+ source = "registry+https://github.com/rust-lang/crates.io-index"
330
+ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
331
+
332
+ [[package]]
333
+ name = "syn"
334
+ version = "2.0.114"
335
+ source = "registry+https://github.com/rust-lang/crates.io-index"
336
+ checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
337
+ dependencies = [
338
+ "proc-macro2",
339
+ "quote",
340
+ "unicode-ident",
341
+ ]
342
+
343
+ [[package]]
344
+ name = "unicode-ident"
345
+ version = "1.0.22"
346
+ source = "registry+https://github.com/rust-lang/crates.io-index"
347
+ checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
348
+
349
+ [[package]]
350
+ name = "windows-link"
351
+ version = "0.2.1"
352
+ source = "registry+https://github.com/rust-lang/crates.io-index"
353
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
@@ -0,0 +1,13 @@
1
+ [package]
2
+ name = "daimond_rust"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+ build = "build.rs"
6
+
7
+ [lib]
8
+ crate-type = ["cdylib"]
9
+
10
+ [dependencies]
11
+ magnus = { version = "0.6", features = ["rb-sys"] }
12
+ rb-sys = "0.9"
13
+ ndarray = "0.15"
@@ -0,0 +1,3 @@
1
+ fn main() {
2
+ // Intentionally empty build script. Original note: "finds Ruby automatically" (presumably done by the rb-sys dependency's own build step, not here; TODO confirm).
3
+ }
@@ -0,0 +1,103 @@
1
+ use magnus::{prelude::*, Ruby, RArray, Value};
2
+ use ndarray::Array2;
3
+ use std::f64;
4
+
5
+ // === Flat MatMul ===
6
+ fn fast_matmul_flat(
7
+ a: Vec<f64>,
8
+ b: Vec<f64>,
9
+ m: usize,
10
+ k: usize,
11
+ n: usize
12
+ ) -> Vec<f64> {
13
+ let a_arr = Array2::from_shape_vec((m, k), a).unwrap();
14
+ let b_arr = Array2::from_shape_vec((k, n), b).unwrap();
15
+ let c_arr = a_arr.dot(&b_arr);
16
+ c_arr.into_raw_vec()
17
+ }
18
+
19
// === Conv2D ===
/// Direct 2-D convolution (stride 1, no padding) over flat buffers.
///
/// Layouts, as implied by the index arithmetic below (all row-major):
/// * `input`  - `[batch, in_c, h, w]`
/// * `weight` - `[out_c, in_c, k, k]`
/// * `bias`   - `[out_c]`
/// * result   - `[batch, out_c, h - k + 1, w - k + 1]`
///
/// Each output value starts from `bias[oc]` and accumulates
/// `input * weight` over every input channel and kernel position.
///
/// # Panics
/// Panics with a descriptive message when `k` is zero, when `k` exceeds `h`
/// or `w`, when a shape product overflows `usize`, or when a buffer is shorter
/// than its declared shape. Previously `h - k + 1` could wrap around in
/// release builds and `k == 0` silently produced an oversized output.
fn conv2d_forward(
    input: Vec<f64>,
    weight: Vec<f64>,
    bias: Vec<f64>,
    batch: usize,
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    k: usize,
) -> Vec<f64> {
    // Overflow-checked product of a list of dimensions.
    fn element_count(dims: &[usize]) -> usize {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .expect("conv2d: shape is too large for usize")
    }

    assert!(k >= 1, "conv2d: kernel size must be at least 1");
    assert!(
        k <= h && k <= w,
        "conv2d: kernel size {} exceeds input size {}x{}",
        k, h, w
    );

    let input_len = element_count(&[batch, in_c, h, w]);
    let weight_len = element_count(&[out_c, in_c, k, k]);
    assert!(
        input.len() >= input_len,
        "conv2d: input has {} values, shape needs {}",
        input.len(), input_len
    );
    assert!(
        weight.len() >= weight_len,
        "conv2d: weight has {} values, shape needs {}",
        weight.len(), weight_len
    );
    assert!(
        bias.len() >= out_c,
        "conv2d: bias has {} values, shape needs {}",
        bias.len(), out_c
    );

    // Safe after the checks above: k <= h and k <= w.
    let h_out = h - k + 1;
    let w_out = w - k + 1;
    let mut output = vec![0.0; element_count(&[batch, out_c, h_out, w_out])];

    for b in 0..batch {
        for oc in 0..out_c {
            for i in 0..h_out {
                for j in 0..w_out {
                    let mut sum = bias[oc];
                    for ic in 0..in_c {
                        for ki in 0..k {
                            // Start of the k-wide input row and kernel row
                            // that line up for this (ic, ki) pair.
                            let in_row = ((b * in_c + ic) * h + i + ki) * w + j;
                            let w_row = ((oc * in_c + ic) * k + ki) * k;
                            let pixels = &input[in_row..in_row + k];
                            let taps = &weight[w_row..w_row + k];
                            for (x, wt) in pixels.iter().zip(taps) {
                                sum += x * wt;
                            }
                        }
                    }
                    output[((b * out_c + oc) * h_out + i) * w_out + j] = sum;
                }
            }
        }
    }
    output
}
57
+
58
// === MaxPool2D ===
/// Non-overlapping `k x k` max pooling over a flat buffer.
///
/// `input` is indexed as `[batch, channels, h, w]` (row-major) and the result
/// as `[batch, channels, h / k, w / k]`. The divisions are integer divisions,
/// so trailing rows/columns that do not fill a whole window are ignored.
///
/// The running maximum starts at negative infinity (the identity for `max`),
/// so a window consisting only of `-inf` values yields `-inf`. The previous
/// `f64::MIN` seed returned `f64::MIN` in that case.
///
/// # Panics
/// Panics with a descriptive message when `k` is zero, when the shape product
/// overflows `usize`, or when `input` is shorter than the declared shape.
fn maxpool2d_forward(
    input: Vec<f64>,
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    k: usize,
) -> Vec<f64> {
    // Overflow-checked product of a list of dimensions.
    fn element_count(dims: &[usize]) -> usize {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .expect("maxpool2d: shape is too large for usize")
    }

    assert!(k >= 1, "maxpool2d: window size must be at least 1");
    let input_len = element_count(&[batch, channels, h, w]);
    assert!(
        input.len() >= input_len,
        "maxpool2d: input has {} values, shape needs {}",
        input.len(), input_len
    );

    let h_out = h / k;
    let w_out = w / k;
    let planes = batch * channels;
    let mut output = Vec::with_capacity(planes * h_out * w_out);

    // A "plane" is one (batch, channel) image; outputs are produced in the
    // same [plane, row, column] order in which they are stored.
    for plane in 0..planes {
        for i in 0..h_out {
            for j in 0..w_out {
                let mut max_val = f64::NEG_INFINITY;
                for ki in 0..k {
                    let row_start = (plane * h + i * k + ki) * w + j * k;
                    for &value in &input[row_start..row_start + k] {
                        max_val = max_val.max(value);
                    }
                }
                output.push(max_val);
            }
        }
    }
    output
}
90
+
91
+ // === Инициализация ===
92
+ #[magnus::init]
93
+ fn init(ruby: &Ruby) -> Result<(), magnus::Error> {
94
+ let module = ruby.define_module("Daimond")?
95
+ .define_module("Rust")?;
96
+
97
+ // Только flat-версии (быстрее и проще)
98
+ module.define_singleton_method("fast_matmul_flat", magnus::function!(fast_matmul_flat, 5))?;
99
+ module.define_singleton_method("conv2d_native", magnus::function!(conv2d_forward, 9))?;
100
+ module.define_singleton_method("maxpool2d_native", magnus::function!(maxpool2d_forward, 6))?;
101
+
102
+ Ok(())
103
+ }
File without changes
@@ -0,0 +1,41 @@
1
module Daimond
  module Data
    # Splits two parallel collections (+images+ and +labels+) into mini-batches
    # and yields every batch as a pair of Tensor objects.
    class DataLoader
      # @param images [#length, #[]] one entry per sample
      # @param labels [#length, #[]] one label per sample, same order as +images+
      # @param batch_size [Integer] samples per batch; must be positive
      # @param shuffle [Boolean] reshuffle the sample order on every #reset
      # @raise [ArgumentError] if +batch_size+ is not a positive Integer or the
      #   two collections differ in length
      def initialize(images, labels, batch_size: 32, shuffle: true)
        unless batch_size.is_a?(Integer) && batch_size.positive?
          # A zero batch size used to loop forever in #each_batch.
          raise ArgumentError, "batch_size must be a positive Integer, got #{batch_size.inspect}"
        end
        if images.length != labels.length
          raise ArgumentError,
                "images (#{images.length}) and labels (#{labels.length}) must have the same length"
        end

        @images = images
        @labels = labels
        @batch_size = batch_size
        @shuffle = shuffle
        @n_samples = images.length
        reset
      end

      # Rebuilds the sample order, shuffling it when shuffling is enabled.
      def reset
        @indices = (0...@n_samples).to_a
        @indices.shuffle! if @shuffle
      end

      # Iterates over one full pass of the data. The order is rebuilt (and
      # reshuffled, if enabled) at the start of every pass; the last batch may
      # be smaller than +batch_size+.
      #
      # @yieldparam x [Tensor] wraps Numo::DFloat built from the batch images
      #   (the original comment describes it as [batch_size, 784])
      # @yieldparam y [Tensor] wraps Numo::Int32 built from the batch labels
      # @return [Enumerator, nil] an Enumerator when called without a block
      def each_batch
        return enum_for(:each_batch) { batches_count } unless block_given?

        reset
        @indices.each_slice(@batch_size) do |batch_indices|
          batch_images = batch_indices.map { |i| @images[i] }
          batch_labels = batch_indices.map { |i| @labels[i] }

          yield Tensor.new(Numo::DFloat[*batch_images]),
                Tensor.new(Numo::Int32[*batch_labels])
        end
        nil
      end

      # @return [Integer] number of batches in one pass (last one may be partial)
      def batches_count
        (@n_samples.to_f / @batch_size).ceil
      end
    end
  end
end
@@ -0,0 +1,56 @@
1
+ require 'open-uri'
2
+ require 'zlib'
3
+ require 'fileutils'
4
+
5
module Daimond
  module Data
    # Downloads the gzip-compressed MNIST files and parses their IDX payloads.
    class MNIST
      # Raised when a file is not a complete IDX file of the expected kind.
      class FormatError < StandardError; end

      URL_BASE = 'https://ossci-datasets.s3.amazonaws.com/mnist/'.freeze
      FILES = {
        train_images: 'train-images-idx3-ubyte.gz',
        train_labels: 'train-labels-idx1-ubyte.gz',
        test_images: 't10k-images-idx3-ubyte.gz',
        test_labels: 't10k-labels-idx1-ubyte.gz'
      }.freeze
      DEFAULT_PATH = 'data/mnist'.freeze

      # Magic numbers that open an IDX image file and an IDX label file.
      IMAGE_MAGIC = 2051
      LABEL_MAGIC = 2049

      # Ensures +file+ exists below +path+, downloading it from URL_BASE if
      # it is missing.
      #
      # The data is streamed into a temporary "<name>.part" file that is
      # renamed only after the transfer finished, so an interrupted download
      # is never mistaken for a complete file on the next run.
      #
      # @param file [String, Symbol] file name, or a key of FILES
      # @param path [String] target directory (created with all parents)
      # @return [String] path of the local file
      def self.download(file, path: DEFAULT_PATH)
        file = FILES.fetch(file, file)
        ::FileUtils.mkdir_p(path) # creates both data and data/mnist
        filepath = File.join(path, file)
        return filepath if File.exist?(filepath)

        puts "Downloading #{file}..."
        partial = "#{filepath}.part"
        begin
          URI.open("#{URL_BASE}#{file}") do |remote|
            File.open(partial, 'wb') { |out| IO.copy_stream(remote, out) }
          end
          File.rename(partial, filepath)
        ensure
          # No-op after a successful rename; removes leftovers on failure.
          ::FileUtils.rm_f(partial)
        end
        filepath
      end

      # Loads an IDX image file.
      #
      # @param file [String, Symbol] file name, or a key of FILES
      # @param path [String] directory holding (or receiving) the file
      # @return [Array<Array<Float>>] one flat array of n_rows * n_cols pixels
      #   per image, each pixel scaled from 0..255 to 0.0..1.0
      # @raise [FormatError] on a wrong magic number or truncated data
      def self.load_images(file, path: DEFAULT_PATH)
        Zlib::GzipReader.open(download(file, path: path)) do |gz|
          n_images, n_rows, n_cols = read_header(gz, IMAGE_MAGIC, 3)
          image_size = n_rows * n_cols
          raise FormatError, 'image dimensions must be non-zero' if image_size.zero?

          pixels = read_exact(gz, n_images * image_size).unpack('C*')
          # Normalize to [0, 1] and split into one flat row per image.
          pixels.each_slice(image_size).map do |img|
            img.map { |pixel| pixel / 255.0 }
          end
        end
      end

      # Loads an IDX label file.
      #
      # @param file [String, Symbol] file name, or a key of FILES
      # @param path [String] directory holding (or receiving) the file
      # @return [Array<Integer>] one label byte per sample
      # @raise [FormatError] on a wrong magic number or truncated data
      def self.load_labels(file, path: DEFAULT_PATH)
        Zlib::GzipReader.open(download(file, path: path)) do |gz|
          n_labels, = read_header(gz, LABEL_MAGIC, 1)
          read_exact(gz, n_labels).unpack('C*')
        end
      end

      # Reads the magic number plus +n_dims+ big-endian 32-bit sizes and
      # returns the sizes after verifying the magic number.
      def self.read_header(io, expected_magic, n_dims)
        magic, *dims = read_exact(io, 4 * (n_dims + 1)).unpack('N*')
        unless magic == expected_magic
          raise FormatError, "unexpected magic number #{magic} (expected #{expected_magic})"
        end

        dims
      end

      # Reads exactly +n_bytes+ bytes or raises FormatError.
      def self.read_exact(io, n_bytes)
        bytes = io.read(n_bytes)
        got = bytes ? bytes.bytesize : 0
        unless got == n_bytes
          raise FormatError, "truncated IDX data: expected #{n_bytes} bytes, got #{got}"
        end

        bytes
      end

      private_class_method :read_header, :read_exact
    end
  end
end
@@ -0,0 +1,45 @@
1
+ require_relative '../tensor'
2
+ require 'numo/narray'
3
+
4
module Daimond
  module Loss
    # Cross-entropy loss computed from class probabilities.
    class CrossEntropyLoss
      # Lower bound applied to a probability before taking its logarithm, so a
      # predicted probability of exactly 0 gives a large finite loss instead
      # of Infinity.
      EPSILON = 1e-12

      # Computes the mean negative log-probability of the target classes and
      # wires up the backward step.
      #
      # @param pred [Tensor] per the original comment: [batch_size, 10]
      #   probabilities produced by softmax
      # @param target [Tensor] per the original comment: [batch_size] class
      #   labels (0-9)
      # @return [Tensor] one-element tensor holding the loss, with +pred+ as
      #   its predecessor and op 'cross_entropy'
      def forward(pred, target)
        batch_size = pred.shape[0]
        probs = pred.data
        labels = target.data

        # Loss value for monitoring (not used in backward). Only the
        # probability of the correct class is needed per sample, so the log
        # is taken over that vector rather than the whole matrix.
        picked = Numo::DFloat.zeros(batch_size)
        batch_size.times do |i|
          prob = probs[i, labels[i]]
          picked[i] = prob < EPSILON ? EPSILON : prob
        end
        loss_value = -Numo::NMath.log(picked).mean

        out = Tensor.new(Numo::DFloat[loss_value], prev: [pred], op: 'cross_entropy')

        # Backward: gradient of cross_entropy + softmax = pred - one_hot(target),
        # averaged over the batch.
        # NOTE(review): the gradient flowing into +out+ is not applied here;
        # this presumably assumes the loss is the root of the graph - confirm.
        out._backward = lambda do
          grad_input = pred.data.dup # softmax output
          batch_size.times do |i|
            grad_input[i, target.data[i]] -= 1.0
          end
          grad_input /= batch_size
          pred.grad += grad_input
        end

        out
      end

      # Alias for #forward so the loss can be invoked like a callable.
      def call(pred, target)
        forward(pred, target)
      end
    end
  end
end
File without changes