rumale-torch 0.1.1 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 32ee2cbc5870b7921c3e6e1aeb9c9bdc045633c71c643d1f8c40f9208b5394cb
4
- data.tar.gz: 0d6d06dc1a320304a5ba69bd54d70cb5983f557df83177df03727aae95e5dd3d
3
+ metadata.gz: 42a236e4bcb36616d0ae30c45848786a6765c1c143ccd33c6b1674a960eabedb
4
+ data.tar.gz: 4cb6612a18e0d80027dce99feeb60e1af10eda8fde28d53570e258ce86031ecc
5
5
  SHA512:
6
- metadata.gz: ce42427042ccea12d20c145bee5ebc4b1a1f56e307b8cdbb4aa6b560b36208b178086a1dbf2f5818dc69ce17c6b57bf97db1d95c21d8ab3d0d8a0e55f6d43a4f
7
- data.tar.gz: 6d5841581a59a7c8cf57be45afdd0c174bcb938c1a85400d8fe35c1d315837fbdadf12b71155f72d7bfcb8e01ffdccc184fe78ff897ab4b6d78b660c199767f8
6
+ metadata.gz: 5752fcc6806297cbc3cf2157414aba1c9d7c284a5cecb86b447419a01c0122acf1a44341099505bfa5f1cf676489f6a8bb541acdad542b2739867b0b09a6315a
7
+ data.tar.gz: 13a317eac5078834e3a2d046561ca6b1cde0ab5cd8c802dec0f070aa47f2538dae3b970a48753fd9b3271228a9fb5dcdf8b7300c07f119e599a0a7be0598e5b6
data/CHANGELOG.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.2.0
2
+ - Refactor to support the new Rumale API.
3
+
1
4
  # 0.1.1
2
5
  - Refator codes and config files.
3
6
  - Introduce commitlint and rubocop.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2020-2022 Atsushi Tatsuma
1
+ Copyright (c) 2020-2023 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,7 +1,6 @@
1
1
  # Rumale::Torch
2
2
 
3
3
  [![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
4
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
5
4
  [![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
6
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
7
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/classifier'
5
5
  require 'rumale/preprocessing/label_encoder'
6
6
  require 'rumale/model_selection/function'
@@ -38,9 +38,8 @@ module Rumale
38
38
  #
39
39
  # classifier.predict(x)
40
40
  #
41
- class NeuralNetClassifier
42
- include Base::BaseEstimator
43
- include Base::Classifier
41
+ class NeuralNetClassifier < Rumale::Base::Estimator
42
+ include Rumale::Base::Classifier
44
43
 
45
44
  # Return the class labels.
46
45
  # @return [Numo::Int32] (size: n_classes)
@@ -80,6 +79,7 @@ module Rumale
80
79
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
81
80
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
82
81
  verbose: false, random_seed: nil)
82
+ super()
83
83
  @model = model
84
84
  @device = device || ::Torch.device('cpu')
85
85
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/regressor'
5
5
  require 'rumale/model_selection/shuffle_split'
6
6
 
@@ -36,9 +36,8 @@ module Rumale
36
36
  #
37
37
  # regressor.predict(x)
38
38
  #
39
- class NeuralNetRegressor
40
- include Base::BaseEstimator
41
- include Base::Regressor
39
+ class NeuralNetRegressor < Rumale::Base::Estimator
40
+ include Rumale::Base::Regressor
42
41
 
43
42
  # Return the neural nets defined with torch.rb.
44
43
  # @return [Torch::NN::Module]
@@ -74,6 +73,7 @@ module Rumale
74
73
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
75
74
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
76
75
  verbose: false, random_seed: nil)
76
+ super()
77
77
  @model = model
78
78
  @device = device || ::Torch.device('cpu')
79
79
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -6,6 +6,6 @@ module Rumale
6
6
  # the neural nets defined in torch.rb with the same interface as Rumale.
7
7
  module Torch
8
8
  # The version of Rumale::Torch you are using.
9
- VERSION = '0.1.1'
9
+ VERSION = '0.2.0'
10
10
  end
11
11
  end
metadata CHANGED
@@ -1,23 +1,34 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-torch
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-05-13 00:00:00.000000000 Z
11
+ date: 2023-01-01 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
- name: rumale
14
+ name: rumale-core
15
15
  requirement: !ruby/object:Gem::Requirement
16
16
  requirements:
17
17
  - - "~>"
18
18
  - !ruby/object:Gem::Version
19
- version: '0.14'
20
- - - "<"
19
+ version: '0.24'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.24'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rumale-model_selection
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
21
32
  - !ruby/object:Gem::Version
22
33
  version: '0.24'
23
34
  type: :runtime
@@ -26,8 +37,19 @@ dependencies:
26
37
  requirements:
27
38
  - - "~>"
28
39
  - !ruby/object:Gem::Version
29
- version: '0.14'
30
- - - "<"
40
+ version: '0.24'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rumale-preprocessing
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '0.24'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
31
53
  - !ruby/object:Gem::Version
32
54
  version: '0.24'
33
55
  - !ruby/object:Gem::Dependency
@@ -85,7 +107,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
107
  - !ruby/object:Gem::Version
86
108
  version: '0'
87
109
  requirements: []
88
- rubygems_version: 3.2.33
110
+ rubygems_version: 3.3.26
89
111
  signing_key:
90
112
  specification_version: 4
91
113
  summary: Rumale::Torch provides the learning and inference by the neural network defined