rumale-torch 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 32ee2cbc5870b7921c3e6e1aeb9c9bdc045633c71c643d1f8c40f9208b5394cb
4
- data.tar.gz: 0d6d06dc1a320304a5ba69bd54d70cb5983f557df83177df03727aae95e5dd3d
3
+ metadata.gz: 42a236e4bcb36616d0ae30c45848786a6765c1c143ccd33c6b1674a960eabedb
4
+ data.tar.gz: 4cb6612a18e0d80027dce99feeb60e1af10eda8fde28d53570e258ce86031ecc
5
5
  SHA512:
6
- metadata.gz: ce42427042ccea12d20c145bee5ebc4b1a1f56e307b8cdbb4aa6b560b36208b178086a1dbf2f5818dc69ce17c6b57bf97db1d95c21d8ab3d0d8a0e55f6d43a4f
7
- data.tar.gz: 6d5841581a59a7c8cf57be45afdd0c174bcb938c1a85400d8fe35c1d315837fbdadf12b71155f72d7bfcb8e01ffdccc184fe78ff897ab4b6d78b660c199767f8
6
+ metadata.gz: 5752fcc6806297cbc3cf2157414aba1c9d7c284a5cecb86b447419a01c0122acf1a44341099505bfa5f1cf676489f6a8bb541acdad542b2739867b0b09a6315a
7
+ data.tar.gz: 13a317eac5078834e3a2d046561ca6b1cde0ab5cd8c802dec0f070aa47f2538dae3b970a48753fd9b3271228a9fb5dcdf8b7300c07f119e599a0a7be0598e5b6
data/CHANGELOG.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.2.0
2
+ - Refactor to support the new Rumale API.
3
+
1
4
  # 0.1.1
2
5
  - Refactor code and config files.
3
6
  - Introduce commitlint and rubocop.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2020-2022 Atsushi Tatsuma
1
+ Copyright (c) 2020-2023 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,7 +1,6 @@
1
1
  # Rumale::Torch
2
2
 
3
3
  [![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
4
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
5
4
  [![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
6
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
7
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/classifier'
5
5
  require 'rumale/preprocessing/label_encoder'
6
6
  require 'rumale/model_selection/function'
@@ -38,9 +38,8 @@ module Rumale
38
38
  #
39
39
  # classifier.predict(x)
40
40
  #
41
- class NeuralNetClassifier
42
- include Base::BaseEstimator
43
- include Base::Classifier
41
+ class NeuralNetClassifier < Rumale::Base::Estimator
42
+ include Rumale::Base::Classifier
44
43
 
45
44
  # Return the class labels.
46
45
  # @return [Numo::Int32] (size: n_classes)
@@ -80,6 +79,7 @@ module Rumale
80
79
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
81
80
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
82
81
  verbose: false, random_seed: nil)
82
+ super()
83
83
  @model = model
84
84
  @device = device || ::Torch.device('cpu')
85
85
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/regressor'
5
5
  require 'rumale/model_selection/shuffle_split'
6
6
 
@@ -36,9 +36,8 @@ module Rumale
36
36
  #
37
37
  # regressor.predict(x)
38
38
  #
39
- class NeuralNetRegressor
40
- include Base::BaseEstimator
41
- include Base::Regressor
39
+ class NeuralNetRegressor < Rumale::Base::Estimator
40
+ include Rumale::Base::Regressor
42
41
 
43
42
  # Return the neural nets defined with torch.rb.
44
43
  # @return [Torch::NN::Module]
@@ -74,6 +73,7 @@ module Rumale
74
73
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
75
74
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
76
75
  verbose: false, random_seed: nil)
76
+ super()
77
77
  @model = model
78
78
  @device = device || ::Torch.device('cpu')
79
79
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -6,6 +6,6 @@ module Rumale
6
6
  # the neural nets defined in torch.rb with the same interface as Rumale.
7
7
  module Torch
8
8
  # The version of Rumale::Torch you are using.
9
- VERSION = '0.1.1'
9
+ VERSION = '0.2.0'
10
10
  end
11
11
  end
metadata CHANGED
@@ -1,23 +1,34 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-torch
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-05-13 00:00:00.000000000 Z
11
+ date: 2023-01-01 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
- name: rumale
14
+ name: rumale-core
15
15
  requirement: !ruby/object:Gem::Requirement
16
16
  requirements:
17
17
  - - "~>"
18
18
  - !ruby/object:Gem::Version
19
- version: '0.14'
20
- - - "<"
19
+ version: '0.24'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '0.24'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rumale-model_selection
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
21
32
  - !ruby/object:Gem::Version
22
33
  version: '0.24'
23
34
  type: :runtime
@@ -26,8 +37,19 @@ dependencies:
26
37
  requirements:
27
38
  - - "~>"
28
39
  - !ruby/object:Gem::Version
29
- version: '0.14'
30
- - - "<"
40
+ version: '0.24'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rumale-preprocessing
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '0.24'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
31
53
  - !ruby/object:Gem::Version
32
54
  version: '0.24'
33
55
  - !ruby/object:Gem::Dependency
@@ -85,7 +107,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
107
  - !ruby/object:Gem::Version
86
108
  version: '0'
87
109
  requirements: []
88
- rubygems_version: 3.2.33
110
+ rubygems_version: 3.3.26
89
111
  signing_key:
90
112
  specification_version: 4
91
113
  summary: Rumale::Torch provides the learning and inference by the neural network defined