rumale-torch 0.1.1 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +0 -1
- data/lib/rumale/torch/neural_net_classifier.rb +4 -4
- data/lib/rumale/torch/neural_net_regressor.rb +6 -4
- data/lib/rumale/torch/version.rb +1 -1
- metadata +32 -13
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 3e192fe0e764bd3fb207eb742f9017b38c09d1f08ded26b7f40d742aa5e4e70f
|
4
|
+
data.tar.gz: 40b0f1a4b7e8f6528195e2399948607059582b15d5198352a731bfab2352ea97
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 31471cac43e90befc16e535da721db81d1626a5e0798704117365fed579e99a177725f5fe03bb115b4949e9bff71c7d682f4db5be35dbcc214b6f6587e47d5b8
|
7
|
+
data.tar.gz: 1498f17f305ec390664f140df4bb97dbca1cc3a690948229d34d9b45d4e160b436fe38354b6bf50c9deac25a43c80240dd576efbbf1f1956310cf39bfc8464d7
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
# Rumale::Torch
|
2
2
|
|
3
3
|
[![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
|
4
|
-
[![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
|
5
4
|
[![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
|
6
5
|
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
|
7
6
|
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/classifier'
|
5
5
|
require 'rumale/preprocessing/label_encoder'
|
6
6
|
require 'rumale/model_selection/function'
|
@@ -38,9 +38,8 @@ module Rumale
|
|
38
38
|
#
|
39
39
|
# classifier.predict(x)
|
40
40
|
#
|
41
|
-
class NeuralNetClassifier
|
42
|
-
include Base::
|
43
|
-
include Base::Classifier
|
41
|
+
class NeuralNetClassifier < Rumale::Base::Estimator
|
42
|
+
include Rumale::Base::Classifier
|
44
43
|
|
45
44
|
# Return the class labels.
|
46
45
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -80,6 +79,7 @@ module Rumale
|
|
80
79
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
81
80
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
82
81
|
verbose: false, random_seed: nil)
|
82
|
+
super()
|
83
83
|
@model = model
|
84
84
|
@device = device || ::Torch.device('cpu')
|
85
85
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/regressor'
|
5
5
|
require 'rumale/model_selection/shuffle_split'
|
6
6
|
|
@@ -36,9 +36,8 @@ module Rumale
|
|
36
36
|
#
|
37
37
|
# regressor.predict(x)
|
38
38
|
#
|
39
|
-
class NeuralNetRegressor
|
40
|
-
include Base::
|
41
|
-
include Base::Regressor
|
39
|
+
class NeuralNetRegressor < Rumale::Base::Estimator
|
40
|
+
include Rumale::Base::Regressor
|
42
41
|
|
43
42
|
# Return the neural nets defined with torch.rb.
|
44
43
|
# @return [Torch::NN::Module]
|
@@ -74,6 +73,7 @@ module Rumale
|
|
74
73
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
75
74
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
76
75
|
verbose: false, random_seed: nil)
|
76
|
+
super()
|
77
77
|
@model = model
|
78
78
|
@device = device || ::Torch.device('cpu')
|
79
79
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
@@ -179,12 +179,14 @@ module Rumale
|
|
179
179
|
end
|
180
180
|
|
181
181
|
def display_epoch(train_loader, test_loader, epoch)
|
182
|
+
# rubocop:disable Lint/FormatParameterMismatch
|
182
183
|
if test_loader.nil?
|
183
184
|
puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f", epoch, evaluate(train_loader)))
|
184
185
|
else
|
185
186
|
puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f - val_loss: %.4f",
|
186
187
|
epoch, evaluate(train_loader), evaluate(test_loader)))
|
187
188
|
end
|
189
|
+
# rubocop:enable Lint/FormatParameterMismatch
|
188
190
|
end
|
189
191
|
|
190
192
|
def evaluate(data_loader)
|
data/lib/rumale/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,33 +1,54 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-torch
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
|
-
autorequire:
|
9
8
|
bindir: exe
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-02 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
|
-
name: rumale
|
13
|
+
name: rumale-core
|
15
14
|
requirement: !ruby/object:Gem::Requirement
|
16
15
|
requirements:
|
17
|
-
- - "
|
16
|
+
- - ">="
|
18
17
|
- !ruby/object:Gem::Version
|
19
|
-
version: '0.
|
20
|
-
|
18
|
+
version: '0.24'
|
19
|
+
type: :runtime
|
20
|
+
prerelease: false
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
22
|
+
requirements:
|
23
|
+
- - ">="
|
24
|
+
- !ruby/object:Gem::Version
|
25
|
+
version: '0.24'
|
26
|
+
- !ruby/object:Gem::Dependency
|
27
|
+
name: rumale-model_selection
|
28
|
+
requirement: !ruby/object:Gem::Requirement
|
29
|
+
requirements:
|
30
|
+
- - ">="
|
21
31
|
- !ruby/object:Gem::Version
|
22
32
|
version: '0.24'
|
23
33
|
type: :runtime
|
24
34
|
prerelease: false
|
25
35
|
version_requirements: !ruby/object:Gem::Requirement
|
26
36
|
requirements:
|
27
|
-
- - "
|
37
|
+
- - ">="
|
38
|
+
- !ruby/object:Gem::Version
|
39
|
+
version: '0.24'
|
40
|
+
- !ruby/object:Gem::Dependency
|
41
|
+
name: rumale-preprocessing
|
42
|
+
requirement: !ruby/object:Gem::Requirement
|
43
|
+
requirements:
|
44
|
+
- - ">="
|
28
45
|
- !ruby/object:Gem::Version
|
29
|
-
version: '0.
|
30
|
-
|
46
|
+
version: '0.24'
|
47
|
+
type: :runtime
|
48
|
+
prerelease: false
|
49
|
+
version_requirements: !ruby/object:Gem::Requirement
|
50
|
+
requirements:
|
51
|
+
- - ">="
|
31
52
|
- !ruby/object:Gem::Version
|
32
53
|
version: '0.24'
|
33
54
|
- !ruby/object:Gem::Dependency
|
@@ -70,7 +91,6 @@ metadata:
|
|
70
91
|
changelog_uri: https://github.com/yoshoku/rumale-torch/blob/main/CHANGELOG.md
|
71
92
|
documentation_uri: https://yoshoku.github.io/rumale-torch/doc/
|
72
93
|
rubygems_mfa_required: 'true'
|
73
|
-
post_install_message:
|
74
94
|
rdoc_options: []
|
75
95
|
require_paths:
|
76
96
|
- lib
|
@@ -85,8 +105,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
105
|
- !ruby/object:Gem::Version
|
86
106
|
version: '0'
|
87
107
|
requirements: []
|
88
|
-
rubygems_version: 3.2
|
89
|
-
signing_key:
|
108
|
+
rubygems_version: 3.6.2
|
90
109
|
specification_version: 4
|
91
110
|
summary: Rumale::Torch provides the learning and inference by the neural network defined
|
92
111
|
in torch.rb with the same interface as Rumale
|