rumale-torch 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +1 -1
- data/README.md +0 -1
- data/lib/rumale/torch/neural_net_classifier.rb +4 -4
- data/lib/rumale/torch/neural_net_regressor.rb +4 -4
- data/lib/rumale/torch/version.rb +1 -1
- metadata +30 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 42a236e4bcb36616d0ae30c45848786a6765c1c143ccd33c6b1674a960eabedb
|
4
|
+
data.tar.gz: 4cb6612a18e0d80027dce99feeb60e1af10eda8fde28d53570e258ce86031ecc
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5752fcc6806297cbc3cf2157414aba1c9d7c284a5cecb86b447419a01c0122acf1a44341099505bfa5f1cf676489f6a8bb541acdad542b2739867b0b09a6315a
|
7
|
+
data.tar.gz: 13a317eac5078834e3a2d046561ca6b1cde0ab5cd8c802dec0f070aa47f2538dae3b970a48753fd9b3271228a9fb5dcdf8b7300c07f119e599a0a7be0598e5b6
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
# Rumale::Torch
|
2
2
|
|
3
3
|
[](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
|
4
|
-
[](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
|
5
4
|
[](https://badge.fury.io/rb/rumale-torch)
|
6
5
|
[](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
|
7
6
|
[](https://yoshoku.github.io/rumale-torch/doc/)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/classifier'
|
5
5
|
require 'rumale/preprocessing/label_encoder'
|
6
6
|
require 'rumale/model_selection/function'
|
@@ -38,9 +38,8 @@ module Rumale
|
|
38
38
|
#
|
39
39
|
# classifier.predict(x)
|
40
40
|
#
|
41
|
-
class NeuralNetClassifier
|
42
|
-
include Base::
|
43
|
-
include Base::Classifier
|
41
|
+
class NeuralNetClassifier < Rumale::Base::Estimator
|
42
|
+
include Rumale::Base::Classifier
|
44
43
|
|
45
44
|
# Return the class labels.
|
46
45
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -80,6 +79,7 @@ module Rumale
|
|
80
79
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
81
80
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
82
81
|
verbose: false, random_seed: nil)
|
82
|
+
super()
|
83
83
|
@model = model
|
84
84
|
@device = device || ::Torch.device('cpu')
|
85
85
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/regressor'
|
5
5
|
require 'rumale/model_selection/shuffle_split'
|
6
6
|
|
@@ -36,9 +36,8 @@ module Rumale
|
|
36
36
|
#
|
37
37
|
# regressor.predict(x)
|
38
38
|
#
|
39
|
-
class NeuralNetRegressor
|
40
|
-
include Base::
|
41
|
-
include Base::Regressor
|
39
|
+
class NeuralNetRegressor < Rumale::Base::Estimator
|
40
|
+
include Rumale::Base::Regressor
|
42
41
|
|
43
42
|
# Return the neural nets defined with torch.rb.
|
44
43
|
# @return [Torch::NN::Module]
|
@@ -74,6 +73,7 @@ module Rumale
|
|
74
73
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
75
74
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
76
75
|
verbose: false, random_seed: nil)
|
76
|
+
super()
|
77
77
|
@model = model
|
78
78
|
@device = device || ::Torch.device('cpu')
|
79
79
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
data/lib/rumale/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,23 +1,34 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-torch
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
|
-
name: rumale
|
14
|
+
name: rumale-core
|
15
15
|
requirement: !ruby/object:Gem::Requirement
|
16
16
|
requirements:
|
17
17
|
- - "~>"
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '0.
|
20
|
-
|
19
|
+
version: '0.24'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0.24'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rumale-model_selection
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
21
32
|
- !ruby/object:Gem::Version
|
22
33
|
version: '0.24'
|
23
34
|
type: :runtime
|
@@ -26,8 +37,19 @@ dependencies:
|
|
26
37
|
requirements:
|
27
38
|
- - "~>"
|
28
39
|
- !ruby/object:Gem::Version
|
29
|
-
version: '0.
|
30
|
-
|
40
|
+
version: '0.24'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rumale-preprocessing
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0.24'
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
31
53
|
- !ruby/object:Gem::Version
|
32
54
|
version: '0.24'
|
33
55
|
- !ruby/object:Gem::Dependency
|
@@ -85,7 +107,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
107
|
- !ruby/object:Gem::Version
|
86
108
|
version: '0'
|
87
109
|
requirements: []
|
88
|
-
rubygems_version: 3.
|
110
|
+
rubygems_version: 3.3.26
|
89
111
|
signing_key:
|
90
112
|
specification_version: 4
|
91
113
|
summary: Rumale::Torch provides the learning and inference by the neural network defined
|