rumale-torch 0.1.1 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 32ee2cbc5870b7921c3e6e1aeb9c9bdc045633c71c643d1f8c40f9208b5394cb
4
- data.tar.gz: 0d6d06dc1a320304a5ba69bd54d70cb5983f557df83177df03727aae95e5dd3d
3
+ metadata.gz: 3e192fe0e764bd3fb207eb742f9017b38c09d1f08ded26b7f40d742aa5e4e70f
4
+ data.tar.gz: 40b0f1a4b7e8f6528195e2399948607059582b15d5198352a731bfab2352ea97
5
5
  SHA512:
6
- metadata.gz: ce42427042ccea12d20c145bee5ebc4b1a1f56e307b8cdbb4aa6b560b36208b178086a1dbf2f5818dc69ce17c6b57bf97db1d95c21d8ab3d0d8a0e55f6d43a4f
7
- data.tar.gz: 6d5841581a59a7c8cf57be45afdd0c174bcb938c1a85400d8fe35c1d315837fbdadf12b71155f72d7bfcb8e01ffdccc184fe78ff897ab4b6d78b660c199767f8
6
+ metadata.gz: 31471cac43e90befc16e535da721db81d1626a5e0798704117365fed579e99a177725f5fe03bb115b4949e9bff71c7d682f4db5be35dbcc214b6f6587e47d5b8
7
+ data.tar.gz: 1498f17f305ec390664f140df4bb97dbca1cc3a690948229d34d9b45d4e160b436fe38354b6bf50c9deac25a43c80240dd576efbbf1f1956310cf39bfc8464d7
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ # 0.3.0
2
+ - Change the version specification of the rumale-core gems.
3
+
4
+ # 0.2.0
5
+ - Refactor to support the new Rumale API.
6
+
1
7
  # 0.1.1
2
8
  - Refactor code and config files.
3
9
  - Introduce commitlint and rubocop.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2020-2022 Atsushi Tatsuma
1
+ Copyright (c) 2020-2025 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,7 +1,6 @@
1
1
  # Rumale::Torch
2
2
 
3
3
  [![Build Status](https://github.com/yoshoku/rumale-torch/workflows/build/badge.svg)](https://github.com/yoshoku/rumale-torch/actions?query=workflow%3Abuild)
4
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale-torch/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale-torch?branch=main)
5
4
  [![Gem Version](https://badge.fury.io/rb/rumale-torch.svg)](https://badge.fury.io/rb/rumale-torch)
6
5
  [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale-torch/blob/main/LICENSE.txt)
7
6
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale-torch/doc/)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/classifier'
5
5
  require 'rumale/preprocessing/label_encoder'
6
6
  require 'rumale/model_selection/function'
@@ -38,9 +38,8 @@ module Rumale
38
38
  #
39
39
  # classifier.predict(x)
40
40
  #
41
- class NeuralNetClassifier
42
- include Base::BaseEstimator
43
- include Base::Classifier
41
+ class NeuralNetClassifier < Rumale::Base::Estimator
42
+ include Rumale::Base::Classifier
44
43
 
45
44
  # Return the class labels.
46
45
  # @return [Numo::Int32] (size: n_classes)
@@ -80,6 +79,7 @@ module Rumale
80
79
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
81
80
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
82
81
  verbose: false, random_seed: nil)
82
+ super()
83
83
  @model = model
84
84
  @device = device || ::Torch.device('cpu')
85
85
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -1,6 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/base/base_estimator'
3
+ require 'rumale/base/estimator'
4
4
  require 'rumale/base/regressor'
5
5
  require 'rumale/model_selection/shuffle_split'
6
6
 
@@ -36,9 +36,8 @@ module Rumale
36
36
  #
37
37
  # regressor.predict(x)
38
38
  #
39
- class NeuralNetRegressor
40
- include Base::BaseEstimator
41
- include Base::Regressor
39
+ class NeuralNetRegressor < Rumale::Base::Estimator
40
+ include Rumale::Base::Regressor
42
41
 
43
42
  # Return the neural nets defined with torch.rb.
44
43
  # @return [Torch::NN::Module]
@@ -74,6 +73,7 @@ module Rumale
74
73
  def initialize(model:, device: nil, optimizer: nil, loss: nil,
75
74
  batch_size: 128, max_epoch: 10, shuffle: true, validation_split: 0,
76
75
  verbose: false, random_seed: nil)
76
+ super()
77
77
  @model = model
78
78
  @device = device || ::Torch.device('cpu')
79
79
  @optimizer = optimizer || ::Torch::Optim::Adam.new(model.parameters)
@@ -179,12 +179,14 @@ module Rumale
179
179
  end
180
180
 
181
181
  def display_epoch(train_loader, test_loader, epoch)
182
+ # rubocop:disable Lint/FormatParameterMismatch
182
183
  if test_loader.nil?
183
184
  puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f", epoch, evaluate(train_loader)))
184
185
  else
185
186
  puts(format("epoch: %#{max_epoch.to_s.length}d/#{max_epoch} - loss: %.4f - val_loss: %.4f",
186
187
  epoch, evaluate(train_loader), evaluate(test_loader)))
187
188
  end
189
+ # rubocop:enable Lint/FormatParameterMismatch
188
190
  end
189
191
 
190
192
  def evaluate(data_loader)
@@ -6,6 +6,6 @@ module Rumale
6
6
  # the neural nets defined in torch.rb with the same interface as Rumale.
7
7
  module Torch
8
8
  # The version of Rumale::Torch you are using.
9
- VERSION = '0.1.1'
9
+ VERSION = '0.3.0'
10
10
  end
11
11
  end
metadata CHANGED
@@ -1,33 +1,54 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-torch
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
- autorequire:
9
8
  bindir: exe
10
9
  cert_chain: []
11
- date: 2022-05-13 00:00:00.000000000 Z
10
+ date: 2025-01-02 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
14
- name: rumale
13
+ name: rumale-core
15
14
  requirement: !ruby/object:Gem::Requirement
16
15
  requirements:
17
- - - "~>"
16
+ - - ">="
18
17
  - !ruby/object:Gem::Version
19
- version: '0.14'
20
- - - "<"
18
+ version: '0.24'
19
+ type: :runtime
20
+ prerelease: false
21
+ version_requirements: !ruby/object:Gem::Requirement
22
+ requirements:
23
+ - - ">="
24
+ - !ruby/object:Gem::Version
25
+ version: '0.24'
26
+ - !ruby/object:Gem::Dependency
27
+ name: rumale-model_selection
28
+ requirement: !ruby/object:Gem::Requirement
29
+ requirements:
30
+ - - ">="
21
31
  - !ruby/object:Gem::Version
22
32
  version: '0.24'
23
33
  type: :runtime
24
34
  prerelease: false
25
35
  version_requirements: !ruby/object:Gem::Requirement
26
36
  requirements:
27
- - - "~>"
37
+ - - ">="
38
+ - !ruby/object:Gem::Version
39
+ version: '0.24'
40
+ - !ruby/object:Gem::Dependency
41
+ name: rumale-preprocessing
42
+ requirement: !ruby/object:Gem::Requirement
43
+ requirements:
44
+ - - ">="
28
45
  - !ruby/object:Gem::Version
29
- version: '0.14'
30
- - - "<"
46
+ version: '0.24'
47
+ type: :runtime
48
+ prerelease: false
49
+ version_requirements: !ruby/object:Gem::Requirement
50
+ requirements:
51
+ - - ">="
31
52
  - !ruby/object:Gem::Version
32
53
  version: '0.24'
33
54
  - !ruby/object:Gem::Dependency
@@ -70,7 +91,6 @@ metadata:
70
91
  changelog_uri: https://github.com/yoshoku/rumale-torch/blob/main/CHANGELOG.md
71
92
  documentation_uri: https://yoshoku.github.io/rumale-torch/doc/
72
93
  rubygems_mfa_required: 'true'
73
- post_install_message:
74
94
  rdoc_options: []
75
95
  require_paths:
76
96
  - lib
@@ -85,8 +105,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
105
  - !ruby/object:Gem::Version
86
106
  version: '0'
87
107
  requirements: []
88
- rubygems_version: 3.2.33
89
- signing_key:
108
+ rubygems_version: 3.6.2
90
109
  specification_version: 4
91
110
  summary: Rumale::Torch provides the learning and inference by the neural network defined
92
111
  in torch.rb with the same interface as Rumale