rumale 0.20.1 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0f361026cd2922a2d36846a817eee855bf0c000156ed6c756bca29d2e42d67a2
4
- data.tar.gz: 016fa40aa2546824cacbc32353263cbfc9427f0ceabb7e703f99854914bb9a2e
3
+ metadata.gz: 5d8c93acbf38fbd07e5df224010abbdd4269a6ce3bbf8112a0eba652a606785d
4
+ data.tar.gz: e7cb00a802420854835c92f011425f3054bfcc1052bf7b3664da1f95834ef435
5
5
  SHA512:
6
- metadata.gz: 7a53a958db7ec8b56236018505370b9908ae81a9afc9d7c8ff0b16d83971539c1ad729b5ab350eb49ae9b90ada43a8912ed2404a37eef97a4d34dad90b1d3e9f
7
- data.tar.gz: 2f2b3d48625c7120464179bc7759c01ba7de85cb0d54720665eaf1e4822f24c1870474ebc24a47cff123e44a8626b0e0fac6a7e81216c057286071770ea5ba79
6
+ metadata.gz: f95fdd89b84dad02e516ee0479b1cddfb101cb96de897b6e7fa3fba546272a243cff5cfe954cb51942ec1ab23cf3028b183db86b52fab00a35d15be7eee5bf92
7
+ data.tar.gz: e5f6235e88dd47b9002a2154cabd2c1e64afb6cbb5b0745b411c7e5559351e925c9db8ec332724e301b83215662b3582e79a9e997f0338846514b234dabf1fc3
@@ -1,3 +1,7 @@
1
+ # 0.20.2
2
+ - Add cross-validator class for time-series data.
3
+ - [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
4
+
1
5
  # 0.20.1
2
6
  - Add cross-validator classes that split data according to group labels.
3
7
  - [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
@@ -103,6 +103,7 @@ require 'rumale/model_selection/stratified_k_fold'
103
103
  require 'rumale/model_selection/shuffle_split'
104
104
  require 'rumale/model_selection/group_shuffle_split'
105
105
  require 'rumale/model_selection/stratified_shuffle_split'
106
+ require 'rumale/model_selection/time_series_split'
106
107
  require 'rumale/model_selection/cross_validation'
107
108
  require 'rumale/model_selection/grid_search_cv'
108
109
  require 'rumale/model_selection/function'
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
module Rumale
  module ModelSelection
    # TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
    # It is assumed that the given dataset is already ordered by time.
    #
    # @example
    #   cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
    #   x = Numo::DFloat.new(6, 2).rand
    #   cv.split(x, nil).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0]
    #   # [1]
    #   # ---
    #   # [0, 1]
    #   # [2]
    #   # ---
    #   # [0, 1, 2]
    #   # [3]
    #   # ---
    #   # [0, 1, 2, 3]
    #   # [4]
    #   # ---
    #   # [0, 1, 2, 3, 4]
    #   # [5]
    #
    class TimeSeriesSplit
      include Base::Splitter

      # Return the number of splits.
      # @return [Integer]
      attr_reader :n_splits

      # Return the maximum number of training samples in a split.
      # @return [Integer/Nil]
      attr_reader :max_train_size

      # Create a new data splitter for time series cross-validation.
      #
      # @param n_splits [Integer] The number of splits.
      # @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
      #   If nil, each training set contains every sample that precedes its test set.
      def initialize(n_splits: 5, max_train_size: nil)
        check_params_numeric(n_splits: n_splits)
        check_params_numeric_or_nil(max_train_size: max_train_size)
        @n_splits = n_splits
        @max_train_size = max_train_size
      end

      # Generate data indices for time series cross-validation.
      #
      # The samples are divided into (n_splits + 1) parts. Every test set contains
      # n_samples / (n_splits + 1) consecutive samples; the remainder is absorbed by the
      # first (training-only) part so that the last test set ends at the final sample.
      # Each training set consists of the samples preceding its test set, limited to the
      # most recent max_train_size samples when max_train_size is given.
      #
      # @overload split(x, y) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for time series cross-validation.
      #     It is expected that the data will be ordered by time.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods; it is not used in the method.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   @raise [ArgumentError] If n_splits + 1 is less than 2 or greater than the number of samples.
      def split(x, _y)
        x = check_convert_sample_array(x)

        n_samples = x.shape[0]
        unless (@n_splits + 1).between?(2, n_samples)
          raise ArgumentError,
                'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
        end

        test_size = n_samples / (@n_splits + 1)
        # The leading training-only part absorbs the remainder of the division.
        offset = test_size + (n_samples % (@n_splits + 1))

        Array.new(@n_splits) do |n|
          # Test windows are consecutive blocks of test_size samples following the first offset samples,
          # so the last window ends exactly at n_samples.
          start = offset + (n * test_size)
          # Keep only the most recent max_train_size samples before the test window.
          train_start = @max_train_size.nil? ? 0 : [start - @max_train_size, 0].max
          train_ids = Array(train_start...start)
          test_ids = Array(start...(start + test_size))
          [train_ids, test_ids]
        end
      end
    end
  end
end
+ end
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.20.1'
6
+ VERSION = '0.20.2'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.20.1
4
+ version: 0.20.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-08-23 00:00:00.000000000 Z
11
+ date: 2020-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -141,6 +141,7 @@ files:
141
141
  - lib/rumale/model_selection/shuffle_split.rb
142
142
  - lib/rumale/model_selection/stratified_k_fold.rb
143
143
  - lib/rumale/model_selection/stratified_shuffle_split.rb
144
+ - lib/rumale/model_selection/time_series_split.rb
144
145
  - lib/rumale/multiclass/one_vs_rest_classifier.rb
145
146
  - lib/rumale/naive_bayes/base_naive_bayes.rb
146
147
  - lib/rumale/naive_bayes/bernoulli_nb.rb