rumale 0.20.1 → 0.20.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/rumale.rb +1 -0
- data/lib/rumale/model_selection/time_series_split.rb +91 -0
- data/lib/rumale/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 5d8c93acbf38fbd07e5df224010abbdd4269a6ce3bbf8112a0eba652a606785d
|
4
|
+
data.tar.gz: e7cb00a802420854835c92f011425f3054bfcc1052bf7b3664da1f95834ef435
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f95fdd89b84dad02e516ee0479b1cddfb101cb96de897b6e7fa3fba546272a243cff5cfe954cb51942ec1ab23cf3028b183db86b52fab00a35d15be7eee5bf92
|
7
|
+
data.tar.gz: e5f6235e88dd47b9002a2154cabd2c1e64afb6cbb5b0745b411c7e5559351e925c9db8ec332724e301b83215662b3582e79a9e997f0338846514b234dabf1fc3
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,7 @@
|
|
1
|
+
# 0.20.2
|
2
|
+
- Add cross-validator class for time-series data.
|
3
|
+
- [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
|
4
|
+
|
1
5
|
# 0.20.1
|
2
6
|
- Add cross-validator classes that split data according to group labels.
|
3
7
|
- [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
|
data/lib/rumale.rb
CHANGED
@@ -103,6 +103,7 @@ require 'rumale/model_selection/stratified_k_fold'
|
|
103
103
|
require 'rumale/model_selection/shuffle_split'
|
104
104
|
require 'rumale/model_selection/group_shuffle_split'
|
105
105
|
require 'rumale/model_selection/stratified_shuffle_split'
|
106
|
+
require 'rumale/model_selection/time_series_split'
|
106
107
|
require 'rumale/model_selection/cross_validation'
|
107
108
|
require 'rumale/model_selection/grid_search_cv'
|
108
109
|
require 'rumale/model_selection/function'
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/splitter'
|
4
|
+
|
5
|
+
module Rumale
|
6
|
+
module ModelSelection
|
7
|
+
# TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
|
8
|
+
# It is assumed that the given dataset is already ordered by time.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
|
12
|
+
# x = Numo::DFloat.new(6, 2).rand
|
13
|
+
# cv.split(x, nil).each do |train_ids, test_ids|
|
14
|
+
# puts '---'
|
15
|
+
# pp train_ids
|
16
|
+
# pp test_ids
|
17
|
+
# end
|
18
|
+
#
|
19
|
+
# # ---
|
20
|
+
# # [0]
|
21
|
+
# # [1]
|
22
|
+
# # ---
|
23
|
+
# # [0, 1]
|
24
|
+
# # [2]
|
25
|
+
# # ---
|
26
|
+
# # [0, 1, 2]
|
27
|
+
# # [3]
|
28
|
+
# # ---
|
29
|
+
# # [0, 1, 2, 3]
|
30
|
+
# # [4]
|
31
|
+
# # ---
|
32
|
+
# # [0, 1, 2, 3, 4]
|
33
|
+
# # [5]
|
34
|
+
#
|
35
|
+
class TimeSeriesSplit
|
36
|
+
include Base::Splitter
|
37
|
+
|
38
|
+
# Return the number of splits.
|
39
|
+
# @return [Integer]
|
40
|
+
attr_reader :n_splits
|
41
|
+
|
42
|
+
# Return the maximum number of training samples in a split.
|
43
|
+
# @return [Integer/Nil]
|
44
|
+
attr_reader :max_train_size
|
45
|
+
|
46
|
+
# Build a splitter that produces time-ordered cross-validation folds.
#
# Both arguments are passed through the parameter-check helpers
# (check_params_numeric / check_params_numeric_or_nil) before being stored.
# NOTE(review): those helpers are defined outside this file; presumably they
# raise on non-numeric input -- confirm against the mixin.
#
# @param n_splits [Integer] The number of splits.
# @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
def initialize(n_splits: 5, max_train_size: nil)
  check_params_numeric(n_splits: n_splits)
  check_params_numeric_or_nil(max_train_size: max_train_size)
  @n_splits, @max_train_size = n_splits, max_train_size
end
|
56
|
+
|
57
|
+
# Generate data indices for time series cross-validation.
#
# The samples are cut into (n_splits + 1) consecutive folds. The k-th split
# trains on everything before the k-th test window and tests on that window,
# so training indices always precede testing indices.
#
# @overload split(x, y) -> Array
#   @param x [Numo::DFloat] (shape: [n_samples, n_features])
#     The dataset to be used to generate data indices for time series cross-validation.
#     It is expected that the data is already ordered by time.
#   @param y [Numo::Int32] (shape: [n_samples])
#     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
#   @return [Array] One [train_ids, test_ids] pair of index arrays per split.
#   @raise [ArgumentError] If (n_splits + 1) is less than 2 or more than the number of samples.
def split(x, _y)
  x = check_convert_sample_array(x)

  n_samples = x.shape[0]
  n_folds = @n_splits + 1
  unless n_folds.between?(2, n_samples)
    raise ArgumentError,
          'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
  end

  test_size = n_samples / n_folds
  # Samples left over by the integer division are absorbed by the first
  # training window, so every test window holds exactly test_size samples.
  offset = test_size + n_samples % n_folds

  Array.new(@n_splits) do |n|
    # Test windows are contiguous and advance by test_size. Multiplying the
    # remainder-adjusted offset instead would push later windows past the
    # last sample whenever n_samples is not divisible by n_folds.
    start = offset + test_size * n
    # Cap the training window only when it would otherwise hold more than
    # max_train_size samples; the candidate training length equals start.
    train_begin = if !@max_train_size.nil? && @max_train_size < start
                    start - @max_train_size
                  else
                    0
                  end
    [Array(train_begin...start), Array(start...(start + test_size))]
  end
end
|
89
|
+
end
|
90
|
+
end
|
91
|
+
end
|
data/lib/rumale/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.20.1
|
4
|
+
version: 0.20.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-09-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -141,6 +141,7 @@ files:
|
|
141
141
|
- lib/rumale/model_selection/shuffle_split.rb
|
142
142
|
- lib/rumale/model_selection/stratified_k_fold.rb
|
143
143
|
- lib/rumale/model_selection/stratified_shuffle_split.rb
|
144
|
+
- lib/rumale/model_selection/time_series_split.rb
|
144
145
|
- lib/rumale/multiclass/one_vs_rest_classifier.rb
|
145
146
|
- lib/rumale/naive_bayes/base_naive_bayes.rb
|
146
147
|
- lib/rumale/naive_bayes/bernoulli_nb.rb
|