rumale 0.20.1 → 0.20.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 0f361026cd2922a2d36846a817eee855bf0c000156ed6c756bca29d2e42d67a2
4
- data.tar.gz: 016fa40aa2546824cacbc32353263cbfc9427f0ceabb7e703f99854914bb9a2e
3
+ metadata.gz: 5d8c93acbf38fbd07e5df224010abbdd4269a6ce3bbf8112a0eba652a606785d
4
+ data.tar.gz: e7cb00a802420854835c92f011425f3054bfcc1052bf7b3664da1f95834ef435
5
5
  SHA512:
6
- metadata.gz: 7a53a958db7ec8b56236018505370b9908ae81a9afc9d7c8ff0b16d83971539c1ad729b5ab350eb49ae9b90ada43a8912ed2404a37eef97a4d34dad90b1d3e9f
7
- data.tar.gz: 2f2b3d48625c7120464179bc7759c01ba7de85cb0d54720665eaf1e4822f24c1870474ebc24a47cff123e44a8626b0e0fac6a7e81216c057286071770ea5ba79
6
+ metadata.gz: f95fdd89b84dad02e516ee0479b1cddfb101cb96de897b6e7fa3fba546272a243cff5cfe954cb51942ec1ab23cf3028b183db86b52fab00a35d15be7eee5bf92
7
+ data.tar.gz: e5f6235e88dd47b9002a2154cabd2c1e64afb6cbb5b0745b411c7e5559351e925c9db8ec332724e301b83215662b3582e79a9e997f0338846514b234dabf1fc3
@@ -1,3 +1,7 @@
1
+ # 0.20.2
2
+ - Add cross-validator class for time-series data.
3
+ - [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
4
+
1
5
  # 0.20.1
2
6
  - Add cross-validator classes that split data according to group labels.
3
7
  - [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
@@ -103,6 +103,7 @@ require 'rumale/model_selection/stratified_k_fold'
103
103
  require 'rumale/model_selection/shuffle_split'
104
104
  require 'rumale/model_selection/group_shuffle_split'
105
105
  require 'rumale/model_selection/stratified_shuffle_split'
106
+ require 'rumale/model_selection/time_series_split'
106
107
  require 'rumale/model_selection/cross_validation'
107
108
  require 'rumale/model_selection/grid_search_cv'
108
109
  require 'rumale/model_selection/function'
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/splitter'
4
+
5
+ module Rumale
6
+ module ModelSelection
7
# TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
# It is assumed that the given dataset is already ordered by time.
#
# Each split uses a block of consecutive samples as the test set and the samples
# that precede it as the training set, so a model is never evaluated on data
# older than the data it was trained on.
#
# @example
#   cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
#   x = Numo::DFloat.new(6, 2).rand
#   cv.split(x, nil).each do |train_ids, test_ids|
#     puts '---'
#     pp train_ids
#     pp test_ids
#   end
#
#   # ---
#   # [0]
#   # [1]
#   # ---
#   # [0, 1]
#   # [2]
#   # ---
#   # [0, 1, 2]
#   # [3]
#   # ---
#   # [0, 1, 2, 3]
#   # [4]
#   # ---
#   # [0, 1, 2, 3, 4]
#   # [5]
#
class TimeSeriesSplit
  include Base::Splitter

  # Return the number of splits.
  # @return [Integer]
  attr_reader :n_splits

  # Return the maximum number of training samples in a split.
  # @return [Integer/Nil]
  attr_reader :max_train_size

  # Create a new data splitter for time series cross-validation.
  #
  # @param n_splits [Integer] The number of splits.
  # @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
  #   If nil, each training set contains every sample that precedes its test set.
  def initialize(n_splits: 5, max_train_size: nil)
    check_params_numeric(n_splits: n_splits)
    check_params_numeric_or_nil(max_train_size: max_train_size)
    @n_splits = n_splits
    @max_train_size = max_train_size
  end

  # Generate data indices for time series cross-validation.
  #
  # @overload split(x, y) -> Array
  #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
  #     The dataset to be used to generate data indices for time series cross-validation.
  #     It is expected that the data will be ordered by time information.
  #   @param y [Numo::Int32] (shape: [n_samples])
  #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
  # @raise [ArgumentError] If n_splits + 1 is less than 2 or greater than the number of samples.
  # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
  #   Each element is a pair [train_ids, test_ids] of Integer arrays.
  def split(x, _y)
    x = check_convert_sample_array(x)

    n_samples = x.shape[0]
    n_folds = @n_splits + 1
    unless n_folds.between?(2, n_samples)
      raise ArgumentError,
            'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
    end

    test_size = n_samples / n_folds
    # Samples left over by the integer division are absorbed by the first training set,
    # so that the last test set ends exactly at the final sample.
    first_test_start = test_size + n_samples % n_folds

    Array.new(@n_splits) do |n|
      # Consecutive test sets are test_size apart; only the first one is shifted by the remainder.
      test_start = first_test_start + test_size * n
      # Truncate the training window only when more than max_train_size samples precede the test set.
      train_start = if @max_train_size && @max_train_size < test_start
                      test_start - @max_train_size
                    else
                      0
                    end
      [Array(train_start...test_start), Array(test_start...(test_start + test_size))]
    end
  end
end
90
+ end
91
+ end
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.20.1'
6
+ VERSION = '0.20.2'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.20.1
4
+ version: 0.20.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2020-08-23 00:00:00.000000000 Z
11
+ date: 2020-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -141,6 +141,7 @@ files:
141
141
  - lib/rumale/model_selection/shuffle_split.rb
142
142
  - lib/rumale/model_selection/stratified_k_fold.rb
143
143
  - lib/rumale/model_selection/stratified_shuffle_split.rb
144
+ - lib/rumale/model_selection/time_series_split.rb
144
145
  - lib/rumale/multiclass/one_vs_rest_classifier.rb
145
146
  - lib/rumale/naive_bayes/base_naive_bayes.rb
146
147
  - lib/rumale/naive_bayes/bernoulli_nb.rb