rumale-torch 0.1.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 32ee2cbc5870b7921c3e6e1aeb9c9bdc045633c71c643d1f8c40f9208b5394cb
4
- data.tar.gz: 0d6d06dc1a320304a5ba69bd54d70cb5983f557df83177df03727aae95e5dd3d
3
+ metadata.gz: 3e192fe0e764bd3fb207eb742f9017b38c09d1f08ded26b7f40d742aa5e4e70f
4
+ data.tar.gz: 40b0f1a4b7e8f6528195e2399948607059582b15d5198352a731bfab2352ea97
5
5
  SHA512:
6
- metadata.gz: ce42427042ccea12d20c145bee5ebc4b1a1f56e307b8cdbb4aa6b560b36208b178086a1dbf2f5818dc69ce17c6b57bf97db1d95c21d8ab3d0d8a0e55f6d43a4f
7
- data.tar.gz: 6d5841581a59a7c8cf57be45afdd0c174bcb938c1a85400d8fe35c1d315837fbdadf12b71155f72d7bfcb8e01ffdccc184fe78ff897ab4b6d78b660c199767f8
6
+ metadata.gz: 31471cac43e90befc16e535da721db81d1626a5e0798704117365fed579e99a177725f5fe03bb115b4949e9bff71c7d682f4db5be35dbcc214b6f6587e47d5b8
7
+ data.tar.gz: 1498f17f305ec390664f140df4bb97dbca1cc3a690948229d34d9b45d4e160b436fe38354b6bf50c9deac25a43c80240dd576efbbf1f1956310cf39bfc8464d7
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ # 0.3.0
2
+ - Change the version specification of rumale-core gems.
3
+
4
+ # 0.2.0
5
+ - Refactor to support the new Rumale API.
6
+
1
7
  # 0.1.1
2
8
  - Refactor codes and config files.
3
9
  - Introduce commitlint and rubocop.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2020-2022 Atsushi Tatsuma
1
+ Copyright (c) 2020-2025 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,7 +1,6 @@
1
1
  # Rumale::Torch
2
2
 
3
3
  [![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
4
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
5
4
  [![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
6
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
7
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/classifier'
5
5
  require 'rumale/preprocessing/label_encoder'
6
6
  require 'rumale/model_selection/function'
@@ -38,9 +38,8 @@ module Rumale
38
38
  #
39
39
  # classifier.predict(x)
40
40
  #
41
- class NeuralNetClassifier
42
- include Base::BaseEstimator
43
- include Base::Classifier
41
+ class NeuralNetClassifier < Rumale::Base::Estimator
42
+ include Rumale::Base::Classifier
44
43
 
45
44
  # Return the class labels.
46
45
  # @return [Numo::Int32] (size: n_classes)
@@ -80,6 +79,7 @@ module Rumale
80
79
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
81
80
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
82
81
  verbose: false, random_seed: nil)
82
+ super()
83
83
  @model = model
84
84
  @device = device || ::Torch.device('cpu')
85
85
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/regressor'
5
5
  require 'rumale/model_selection/shuffle_split'
6
6
 
@@ -36,9 +36,8 @@ module Rumale
36
36
  #
37
37
  # regressor.predict(x)
38
38
  #
39
- class NeuralNetRegressor
40
- include Base::BaseEstimator
41
- include Base::Regressor
39
+ class NeuralNetRegressor < Rumale::Base::Estimator
40
+ include Rumale::Base::Regressor
42
41
 
43
42
  # Return the neural nets defined with torch.rb.
44
43
  # @return [Torch::NN::Module]
@@ -74,6 +73,7 @@ module Rumale
74
73
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
75
74
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
76
75
  verbose: false, random_seed: nil)
76
+ super()
77
77
  @model = model
78
78
  @device = device || ::Torch.device('cpu')
79
79
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -179,12 +179,14 @@ module Rumale
179
179
  end
180
180
 
181
181
  def display_epoch(train_loader, test_loader, epoch)
182
+ # rubocop:disable Lint/FormatParameterMismatch
182
183
  if test_loader.nil?
183
184
  puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f", epoch, evaluate(train_loader)))
184
185
  else
185
186
  puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f - val_loss: %.4f",
186
187
  epoch, evaluate(train_loader), evaluate(test_loader)))
187
188
  end
189
+ # rubocop:enable Lint/FormatParameterMismatch
188
190
  end
189
191
 
190
192
  def evaluate(data_loader)
@@ -6,6 +6,6 @@ module Rumale
6
6
  # the neural nets defined in torch.rb with the same interface as Rumale.
7
7
  module Torch
8
8
  # The version of Rumale::Torch you are using.
9
- VERSION = '0.1.1'
9
+ VERSION = '0.3.0'
10
10
  end
11
11
  end
metadata CHANGED
@@ -1,33 +1,54 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-torch
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
9
8
  bindir: exe
10
9
  cert_chain: []
11
- date: 2022-05-13 00:00:00.000000000 Z
10
+ date: 2025-01-02 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
14
- name: rumale
13
+ name: rumale-core
15
14
  requirement: !ruby/object:Gem::Requirement
16
15
  requirements:
17
- - - "~>"
16
+ - - ">="
18
17
  - !ruby/object:Gem::Version
19
- version: '0.14'
20
- - - "<"
18
+ version: '0.24'
19
+ type: :runtime
20
+ prerelease: false
21
+ version_requirements: !ruby/object:Gem::Requirement
22
+ requirements:
23
+ - - ">="
24
+ - !ruby/object:Gem::Version
25
+ version: '0.24'
26
+ - !ruby/object:Gem::Dependency
27
+ name: rumale-model_selection
28
+ requirement: !ruby/object:Gem::Requirement
29
+ requirements:
30
+ - - ">="
21
31
  - !ruby/object:Gem::Version
22
32
  version: '0.24'
23
33
  type: :runtime
24
34
  prerelease: false
25
35
  version_requirements: !ruby/object:Gem::Requirement
26
36
  requirements:
27
- - - "~>"
37
+ - - ">="
38
+ - !ruby/object:Gem::Version
39
+ version: '0.24'
40
+ - !ruby/object:Gem::Dependency
41
+ name: rumale-preprocessing
42
+ requirement: !ruby/object:Gem::Requirement
43
+ requirements:
44
+ - - ">="
28
45
  - !ruby/object:Gem::Version
29
- version: '0.14'
30
- - - "<"
46
+ version: '0.24'
47
+ type: :runtime
48
+ prerelease: false
49
+ version_requirements: !ruby/object:Gem::Requirement
50
+ requirements:
51
+ - - ">="
31
52
  - !ruby/object:Gem::Version
32
53
  version: '0.24'
33
54
  - !ruby/object:Gem::Dependency
@@ -70,7 +91,6 @@ metadata:
70
91
  changelog_uri: https://github.com/yoshoku/rumale-torch/blob/main/CHANGELOG.md
71
92
  documentation_uri: https://yoshoku.github.io/rumale-torch/doc/
72
93
  rubygems_mfa_required: 'true'
73
- post_install_message:
74
94
  rdoc_options: []
75
95
  require_paths:
76
96
  - lib
@@ -85,8 +105,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
105
  - !ruby/object:Gem::Version
86
106
  version: '0'
87
107
  requirements: []
88
- rubygems_version: 3.2.33
89
- signing_key:
108
+ rubygems_version: 3.6.2
90
109
  specification_version: 4
91
110
  summary: Rumale::Torch provides the learning and inference by the neural network defined
92
111
  in torch.rb with the same interface as Rumale