rumale-torch 0.1.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +0 -1
- data/lib/rumale/torch/neural_net_classifier.rb +4 -4
- data/lib/rumale/torch/neural_net_regressor.rb +6 -4
- data/lib/rumale/torch/version.rb +1 -1
- metadata +32 -13
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 3e192fe0e764bd3fb207eb742f9017b38c09d1f08ded26b7f40d742aa5e4e70f
|
4
|
+
data.tar.gz: 40b0f1a4b7e8f6528195e2399948607059582b15d5198352a731bfab2352ea97
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 31471cac43e90befc16e535da721db81d1626a5e0798704117365fed579e99a177725f5fe03bb115b4949e9bff71c7d682f4db5be35dbcc214b6f6587e47d5b8
|
7
|
+
data.tar.gz: 1498f17f305ec390664f140df4bb97dbca1cc3a690948229d34d9b45d4e160b436fe38354b6bf50c9deac25a43c80240dd576efbbf1f1956310cf39bfc8464d7
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
# Rumale::Torch
|
2
2
|
|
3
3
|
[](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
|
4
|
-
[](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
|
5
4
|
[](https://badge.fury.io/rb/rumale-torch)
|
6
5
|
[](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
|
7
6
|
[](https://yoshoku.github.io/rumale-torch/doc/)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/classifier'
|
5
5
|
require 'rumale/preprocessing/label_encoder'
|
6
6
|
require 'rumale/model_selection/function'
|
@@ -38,9 +38,8 @@ module Rumale
|
|
38
38
|
#
|
39
39
|
# classifier.predict(x)
|
40
40
|
#
|
41
|
-
class NeuralNetClassifier
|
42
|
-
include Base::
|
43
|
-
include Base::Classifier
|
41
|
+
class NeuralNetClassifier < Rumale::Base::Estimator
|
42
|
+
include Rumale::Base::Classifier
|
44
43
|
|
45
44
|
# Return the class labels.
|
46
45
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -80,6 +79,7 @@ module Rumale
|
|
80
79
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
81
80
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
82
81
|
verbose: false, random_seed: nil)
|
82
|
+
super()
|
83
83
|
@model = model
|
84
84
|
@device = device || ::Torch.device('cpu')
|
85
85
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/regressor'
|
5
5
|
require 'rumale/model_selection/shuffle_split'
|
6
6
|
|
@@ -36,9 +36,8 @@ module Rumale
|
|
36
36
|
#
|
37
37
|
# regressor.predict(x)
|
38
38
|
#
|
39
|
-
class NeuralNetRegressor
|
40
|
-
include Base::
|
41
|
-
include Base::Regressor
|
39
|
+
class NeuralNetRegressor < Rumale::Base::Estimator
|
40
|
+
include Rumale::Base::Regressor
|
42
41
|
|
43
42
|
# Return the neural nets defined with torch.rb.
|
44
43
|
# @return [Torch::NN::Module]
|
@@ -74,6 +73,7 @@ module Rumale
|
|
74
73
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
75
74
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
76
75
|
verbose: false, random_seed: nil)
|
76
|
+
super()
|
77
77
|
@model = model
|
78
78
|
@device = device || ::Torch.device('cpu')
|
79
79
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
@@ -179,12 +179,14 @@ module Rumale
|
|
179
179
|
end
|
180
180
|
|
181
181
|
def display_epoch(train_loader, test_loader, epoch)
|
182
|
+
# rubocop:disable Lint/FormatParameterMismatch
|
182
183
|
if test_loader.nil?
|
183
184
|
puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f", epoch, evaluate(train_loader)))
|
184
185
|
else
|
185
186
|
puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f - val_loss: %.4f",
|
186
187
|
epoch, evaluate(train_loader), evaluate(test_loader)))
|
187
188
|
end
|
189
|
+
# rubocop:enable Lint/FormatParameterMismatch
|
188
190
|
end
|
189
191
|
|
190
192
|
def evaluate(data_loader)
|
data/lib/rumale/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,33 +1,54 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-torch
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
|
-
autorequire:
|
9
8
|
bindir: exe
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-02 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
|
-
name: rumale
|
13
|
+
name: rumale-core
|
15
14
|
requirement: !ruby/object:Gem::Requirement
|
16
15
|
requirements:
|
17
|
-
- - "
|
16
|
+
- - ">="
|
18
17
|
- !ruby/object:Gem::Version
|
19
|
-
version: '0.
|
20
|
-
|
18
|
+
version: '0.24'
|
19
|
+
type: :runtime
|
20
|
+
prerelease: false
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
22
|
+
requirements:
|
23
|
+
- - ">="
|
24
|
+
- !ruby/object:Gem::Version
|
25
|
+
version: '0.24'
|
26
|
+
- !ruby/object:Gem::Dependency
|
27
|
+
name: rumale-model_selection
|
28
|
+
requirement: !ruby/object:Gem::Requirement
|
29
|
+
requirements:
|
30
|
+
- - ">="
|
21
31
|
- !ruby/object:Gem::Version
|
22
32
|
version: '0.24'
|
23
33
|
type: :runtime
|
24
34
|
prerelease: false
|
25
35
|
version_requirements: !ruby/object:Gem::Requirement
|
26
36
|
requirements:
|
27
|
-
- - "
|
37
|
+
- - ">="
|
38
|
+
- !ruby/object:Gem::Version
|
39
|
+
version: '0.24'
|
40
|
+
- !ruby/object:Gem::Dependency
|
41
|
+
name: rumale-preprocessing
|
42
|
+
requirement: !ruby/object:Gem::Requirement
|
43
|
+
requirements:
|
44
|
+
- - ">="
|
28
45
|
- !ruby/object:Gem::Version
|
29
|
-
version: '0.
|
30
|
-
|
46
|
+
version: '0.24'
|
47
|
+
type: :runtime
|
48
|
+
prerelease: false
|
49
|
+
version_requirements: !ruby/object:Gem::Requirement
|
50
|
+
requirements:
|
51
|
+
- - ">="
|
31
52
|
- !ruby/object:Gem::Version
|
32
53
|
version: '0.24'
|
33
54
|
- !ruby/object:Gem::Dependency
|
@@ -70,7 +91,6 @@ metadata:
|
|
70
91
|
changelog_uri: https://github.com/yoshoku/rumale-torch/blob/main/CHANGELOG.md
|
71
92
|
documentation_uri: https://yoshoku.github.io/rumale-torch/doc/
|
72
93
|
rubygems_mfa_required: 'true'
|
73
|
-
post_install_message:
|
74
94
|
rdoc_options: []
|
75
95
|
require_paths:
|
76
96
|
- lib
|
@@ -85,8 +105,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
105
|
- !ruby/object:Gem::Version
|
86
106
|
version: '0'
|
87
107
|
requirements: []
|
88
|
-
rubygems_version: 3.2
|
89
|
-
signing_key:
|
108
|
+
rubygems_version: 3.6.2
|
90
109
|
specification_version: 4
|
91
110
|
summary: Rumale::Torch provides the learning and inference by the neural network defined
|
92
111
|
in torch.rb with the same interface as Rumale
|