rumale-torch 0.1.1 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +1 -1
- data/README.md +0 -1
- data/lib/rumale/torch/neural_net_classifier.rb +4 -4
- data/lib/rumale/torch/neural_net_regressor.rb +4 -4
- data/lib/rumale/torch/version.rb +1 -1
- metadata +30 -8
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 42a236e4bcb36616d0ae30c45848786a6765c1c143ccd33c6b1674a960eabedb
|
4
|
+
data.tar.gz: 4cb6612a18e0d80027dce99feeb60e1af10eda8fde28d53570e258ce86031ecc
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5752fcc6806297cbc3cf2157414aba1c9d7c284a5cecb86b447419a01c0122acf1a44341099505bfa5f1cf676489f6a8bb541acdad542b2739867b0b09a6315a
|
7
|
+
data.tar.gz: 13a317eac5078834e3a2d046561ca6b1cde0ab5cd8c802dec0f070aa47f2538dae3b970a48753fd9b3271228a9fb5dcdf8b7300c07f119e599a0a7be0598e5b6
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
# Rumale::Torch
|
2
2
|
|
3
3
|
[![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
|
4
|
-
[![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
|
5
4
|
[![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
|
6
5
|
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
|
7
6
|
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
|
data/lib/rumale/torch/neural_net_classifier.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/base_estimator'
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/classifier'
|
5
5
|
require 'rumale/preprocessing/label_encoder'
|
6
6
|
require 'rumale/model_selection/function'
|
@@ -38,9 +38,8 @@ module Rumale
|
|
38
38
|
#
|
39
39
|
# classifier.predict(x)
|
40
40
|
#
|
41
|
-
class NeuralNetClassifier
|
42
|
-
include Base::BaseEstimator
|
43
|
-
include Base::Classifier
|
41
|
+
class NeuralNetClassifier < Rumale::Base::Estimator
|
42
|
+
include Rumale::Base::Classifier
|
44
43
|
|
45
44
|
# Return the class labels.
|
46
45
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -80,6 +79,7 @@ module Rumale
|
|
80
79
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
81
80
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
82
81
|
verbose: false, random_seed: nil)
|
82
|
+
super()
|
83
83
|
@model = model
|
84
84
|
@device = device || ::Torch.device('cpu')
|
85
85
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
data/lib/rumale/torch/neural_net_regressor.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/base_estimator'
|
3
|
+
require 'rumale/base/estimator'
|
4
4
|
require 'rumale/base/regressor'
|
5
5
|
require 'rumale/model_selection/shuffle_split'
|
6
6
|
|
@@ -36,9 +36,8 @@ module Rumale
|
|
36
36
|
#
|
37
37
|
# regressor.predict(x)
|
38
38
|
#
|
39
|
-
class NeuralNetRegressor
|
40
|
-
include Base::BaseEstimator
|
41
|
-
include Base::Regressor
|
39
|
+
class NeuralNetRegressor < Rumale::Base::Estimator
|
40
|
+
include Rumale::Base::Regressor
|
42
41
|
|
43
42
|
# Return the neural nets defined with torch.rb.
|
44
43
|
# @return [Torch::NN::Module]
|
@@ -74,6 +73,7 @@ module Rumale
|
|
74
73
|
def initialize(model:, device: nil, optimizer: nil, loss: nil,
|
75
74
|
batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
|
76
75
|
verbose: false, random_seed: nil)
|
76
|
+
super()
|
77
77
|
@model = model
|
78
78
|
@device = device || ::Torch.device('cpu')
|
79
79
|
@optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
|
data/lib/rumale/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,23 +1,34 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-torch
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.1
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-01-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
|
-
name: rumale
|
14
|
+
name: rumale-core
|
15
15
|
requirement: !ruby/object:Gem::Requirement
|
16
16
|
requirements:
|
17
17
|
- - "~>"
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '0.
|
20
|
-
|
19
|
+
version: '0.24'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0.24'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rumale-model_selection
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
21
32
|
- !ruby/object:Gem::Version
|
22
33
|
version: '0.24'
|
23
34
|
type: :runtime
|
@@ -26,8 +37,19 @@ dependencies:
|
|
26
37
|
requirements:
|
27
38
|
- - "~>"
|
28
39
|
- !ruby/object:Gem::Version
|
29
|
-
version: '0.
|
30
|
-
|
40
|
+
version: '0.24'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rumale-preprocessing
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0.24'
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
31
53
|
- !ruby/object:Gem::Version
|
32
54
|
version: '0.24'
|
33
55
|
- !ruby/object:Gem::Dependency
|
@@ -85,7 +107,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
107
|
- !ruby/object:Gem::Version
|
86
108
|
version: '0'
|
87
109
|
requirements: []
|
88
|
-
rubygems_version: 3.
|
110
|
+
rubygems_version: 3.3.26
|
89
111
|
signing_key:
|
90
112
|
specification_version: 4
|
91
113
|
summary: Rumale::Torch provides the learning and inference by the neural network defined
|