svmkit 0.7.3 → 0.8.1

This diff shows the content of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/.rubocop_todo.yml DELETED
@@ -1,58 +0,0 @@
- # This configuration was generated by
- # `rubocop --auto-gen-config`
- # on 2018-06-10 12:21:53 +0900 using RuboCop version 0.57.1.
- # The point is for the user to remove these configuration records
- # one by one as the offenses are removed from the code base.
- # Note that changes in the inspected code, or installation of new
- # versions of RuboCop, may require this file to be generated again.
-
- # Offense count: 2
- # Cop supports --auto-correct.
- Layout/ClosingHeredocIndentation:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 2
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle.
- # SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
- Layout/IndentHeredoc:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 1
- # Cop supports --auto-correct.
- Layout/LeadingBlankLines:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 1
- # Configuration parameters: CountComments, ExcludedMethods.
- Metrics/BlockLength:
-   Max: 29
-
- # Offense count: 3
- Metrics/CyclomaticComplexity:
-   Max: 12
-
- # Offense count: 3
- Metrics/PerceivedComplexity:
-   Max: 13
-
- # Offense count: 1
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
- # SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
- Style/HashSyntax:
-   Exclude:
-     - 'Rakefile'
-
- # Offense count: 6
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
- # SupportedStyles: single_quotes, double_quotes
- Style/StringLiterals:
-   Exclude:
-     - 'Gemfile'
-     - 'Rakefile'
-     - 'bin/console'
data/HISTORY.md DELETED
@@ -1,168 +0,0 @@
- # 0.7.3
- - Add class for grid search performing hyperparameter optimization.
- - Add argument validations to Pipeline.
-
- # 0.7.2
- - Add class for Pipeline that constructs chain of transformers and estimators.
- - Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
-
- # 0.7.1
- - Fix to use CSV class in parsing libsvm format file.
- - Refactor ensemble estimators.
-
- # 0.7.0
- - Add class for AdaBoost classifier.
- - Add class for AdaBoost regressor.
-
- # 0.6.3
- - Fix bug on setting random seed and max_features parameter of Random Forest estimators.
-
- # 0.6.2
- - Refactor decision tree classes for improving performance.
-
- # 0.6.1
- - Add abstract class for linear estimator with stochastic gradient descent.
- - Refactor linear estimators to use linear estimator abstract class.
- - Refactor decision tree classes to avoid unneeded type conversion.
-
- # 0.6.0
- - Add class for Principal Component Analysis.
- - Add class for Non-negative Matrix Factorization.
-
- # 0.5.2
- - Add class for DBSCAN clustering.
-
- # 0.5.1
- - Fix bug on class probability calculation of DecisionTreeClassifier.
-
- # 0.5.0
- - Add class for K-Means clustering.
- - Add class for evaluating purity.
- - Add class for evaluating normalized mutual information.
-
- # 0.4.1
- - Add class for linear regressor.
- - Add class for SGD optimizer.
- - Add class for RMSProp optimizer.
- - Add class for YellowFin optimizer.
- - Fix to be able to select optimizer on estimators of LinearModel and PolynomialModel.
-
- # 0.4.0
- ## Breaking changes
-
- SVMKit introduces optimizer algorithms that calculate learning rates adaptively
- on each iteration of stochastic gradient descent (SGD).
- While Pegasos SGD runs fast, it sometimes fails to optimize complicated models
- like Factorization Machine.
- To solve this problem, in version 0.3.3, SVMKit introduced optimization with RMSProp on
- FactorizationMachineRegressor, Ridge and Lasso.
- This made the optimization of those estimators stable.
- Following the success of that attempt, the author decided to use modern optimizer algorithms
- for all SGD optimizations in SVMKit.
- Based on some preliminary experiments, the author chose Nadam as the default optimizer.
- SVMKit plans to add other optimizer algorithms sequentially, so that users can select them.
-
- - Fix to use Nadam for optimization on SVC, SVR, LogisticRegression, Ridge, Lasso, and Factorization Machine estimators.
- - Combine reg_param_weight and reg_param_bias parameters on Factorization Machine estimators into the unified parameter named reg_param_linear.
- - Remove init_std parameter on Factorization Machine estimators.
- - Remove learning_rate, decay, and momentum parameters on Ridge, Lasso, and FactorizationMachineRegressor.
- - Remove normalize parameter on SVC, SVR, and LogisticRegression.
-
- # 0.3.3
- - Add class for Ridge regressor.
- - Add class for Lasso regressor.
- - Fix bug on gradient calculation of FactorizationMachineRegressor.
- - Fix some documents.
-
- # 0.3.2
- - Add class for Factorization Machine regressor.
- - Add class for Decision Tree regressor.
- - Add class for Random Forest regressor.
- - Fix to support loading and dumping libsvm files with multi-target variables.
- - Fix to require DecisionTreeClassifier on RandomForestClassifier.
- - Fix some mistakes on document.
-
- # 0.3.1
- - Fix bug on decision function calculation of FactorizationMachineClassifier.
- - Fix bug on weight updating process of KernelSVC.
-
- # 0.3.0
- - Add class for Support Vector Regression.
- - Add class for K-Nearest Neighbor Regression.
- - Add class for evaluating coefficient of determination.
- - Add class for evaluating mean squared error.
- - Add class for evaluating mean absolute error.
- - Fix to use min method instead of sort and first methods.
- - Fix cross validation class to be usable for regression problems.
- - Fix some typos on document.
- - Rename spec filename for Factorization Machine classifier.
-
- # 0.2.9
- - Add predict_proba method to SVC and KernelSVC.
- - Add class for evaluating logarithmic loss.
- - Add classes for Label- and One-Hot-encoding.
- - Add some validators.
- - Fix bug on training data score calculation of cross validation.
- - Fix fit method of SVC for performance.
- - Fix criterion calculation on Decision Tree for performance.
- - Fix data structure of Decision Tree for performance.
-
- # 0.2.8
- - Fix bug on gradient calculation of Logistic Regression.
- - Fix to change the params accessor of estimators to read-only.
- - Add parameter validation.
-
- # 0.2.7
- - Fix to support multiclass classification in LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier.
-
- # 0.2.6
- - Add class for Decision Tree classifier.
- - Add class for Random Forest classifier.
- - Fix to use frozen string literal.
- - Refactor marshal dump method on some classes.
- - Introduce Coveralls to confirm test coverage.
-
- # 0.2.5
- - Add classes for Naive Bayes classifier.
- - Fix decision function method on Logistic Regression class.
- - Fix method visibility on RBF kernel approximation class.
-
- # 0.2.4
- - Add class for Factorization Machine classifier.
- - Add classes for evaluation measures.
- - Fix the method for prediction of class probability in Logistic Regression.
-
- # 0.2.3
- - Add class for cross validation.
- - Add specs for base modules.
- - Fix validation of the number of splits when a negative label is given.
-
- # 0.2.2
- - Add data splitter classes for K-fold cross validation.
-
- # 0.2.1
- - Add class for K-nearest neighbors classifier.
-
- # 0.2.0
- - Migrate the linear algebra library to Numo::NArray.
- - Add module for loading and saving libsvm format file.
-
- # 0.1.3
- - Add class for Kernel Support Vector Machine with Pegasos algorithm.
- - Add module for calculating pairwise kernel functions and Euclidean distances.
-
- # 0.1.2
- - Add the ability to learn a model with a bias term to the PegasosSVC and LogisticRegression classes.
- - Rewrite the document with yard notation.
-
- # 0.1.1
- - Add class for Logistic Regression with SGD optimization.
- - Fix some mistakes on the document.
-
- # 0.1.0
- - Add basic classes.
- - Add a utility module.
- - Add class for RBF kernel approximation.
- - Add class for Support Vector Machine with Pegasos algorithm.
- - Add class that performs multiclass classification with one-vs.-rest strategy.
- - Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
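
The 0.4.x entries above describe switching the SGD-based estimators to selectable, adaptive optimizers, with Nadam as the default. As a rough sketch of that usage (the optimizer: keyword and the constructor arguments shown here are inferred from the deleted optimizer classes listed in this diff and may not match the 0.7.3 signatures exactly):

require 'svmkit'

# Hypothetical illustration: pass an optimizer instance to an SGD-based
# estimator instead of relying on the default Nadam optimizer.
# Keyword names are assumptions, not confirmed by this diff.
optimizer = SVMKit::Optimizer::SGD.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
svc = SVMKit::LinearModel::SVC.new(reg_param: 1.0, max_iter: 1000, optimizer: optimizer)
# svc.fit(x_train, y_train) would then update the weights with plain SGD.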
data/lib/svmkit/base/base_estimator.rb DELETED
@@ -1,13 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   # This module consists of basic mix-in classes.
-   module Base
-     # Base module for all estimators in SVMKit.
-     module BaseEstimator
-       # Return parameters about an estimator.
-       # @return [Hash]
-       attr_reader :params
-     end
-   end
- end
data/lib/svmkit/base/classifier.rb DELETED
@@ -1,34 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/accuracy'
-
- module SVMKit
-   module Base
-     # Module for all classifiers in SVMKit.
-     module Classifier
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for predicting labels.
-       def predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate the mean accuracy of the given testing data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-       # @return [Float] Mean accuracy
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::Accuracy.new
-         evaluator.score(y, predict(x))
-       end
-     end
-   end
- end
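
For context, a minimal sketch of how these mix-ins were meant to be combined in 0.7.3 (not part of the gem; the class name and the majority-vote logic are made up for illustration). A class that implements fit and predict gets score, i.e. mean accuracy, from Base::Classifier for free:

require 'numo/narray'
require 'svmkit'

# Toy classifier that always predicts the most frequent training label.
class MajorityClassifier
  include SVMKit::Base::BaseEstimator
  include SVMKit::Base::Classifier

  def fit(_x, y)
    # Remember the most frequent label seen during training.
    @majority = y.to_a.group_by(&:itself).max_by { |_, ids| ids.size }.first
    self
  end

  def predict(x)
    # Return the stored label for every sample.
    Numo::Int32.new(x.shape[0]).fill(@majority)
  end
end

# clf = MajorityClassifier.new; clf.fit(x_train, y_train)
# clf.score(x_test, y_test) then works via Base::Classifier#score.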
data/lib/svmkit/base/cluster_analyzer.rb DELETED
@@ -1,29 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/purity'
-
- module SVMKit
-   module Base
-     # Module for all clustering algorithms in SVMKit.
-     module ClusterAnalyzer
-       # An abstract method for analyzing clusters and predicting cluster indices.
-       def fit_predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate purity of clustering result.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-       # @return [Float] Purity
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::Purity.new
-         evaluator.score(y, fit_predict(x))
-       end
-     end
-   end
- end
data/lib/svmkit/base/evaluator.rb DELETED
@@ -1,13 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all evaluation measures in SVMKit.
-     module Evaluator
-       # An abstract method for evaluation of model.
-       def score
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
data/lib/svmkit/base/regressor.rb DELETED
@@ -1,34 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/r2_score'
-
- module SVMKit
-   module Base
-     # Module for all regressors in SVMKit.
-     module Regressor
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for predicting values.
-       def predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate the coefficient of determination for the given testing data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
-       # @return [Float] Coefficient of determination
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_tvalue_array(y)
-         SVMKit::Validation.check_sample_tvalue_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::R2Score.new
-         evaluator.score(y, predict(x))
-       end
-     end
-   end
- end
data/lib/svmkit/base/splitter.rb DELETED
@@ -1,17 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all validation methods in SVMKit.
-     module Splitter
-       # Return the number of splits.
-       # @return [Integer]
-       attr_reader :n_splits
-
-       # An abstract method for splitting dataset.
-       def split
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
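
For context, concrete splitters such as SVMKit::ModelSelection::KFold (also deleted in this release, see the file list above) build on this module. A minimal usage sketch, assuming the keyword arguments shown here and with samples and labels standing in for Numo arrays:

require 'svmkit'

# Sketch of the Splitter interface via KFold (argument names are assumptions).
kfold = SVMKit::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kfold.n_splits # => 3
kfold.split(samples, labels).each do |train_ids, test_ids|
  # train_ids and test_ids are arrays of row indices into samples.
end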
data/lib/svmkit/base/transformer.rb DELETED
@@ -1,18 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all transformers in SVMKit.
-     module Transformer
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for fitting a model and transforming given data.
-       def fit_transform
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
data/lib/svmkit/clustering/dbscan.rb DELETED
@@ -1,127 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/cluster_analyzer'
- require 'svmkit/pairwise_metric'
-
- module SVMKit
-   module Clustering
-     # DBSCAN is a class that implements DBSCAN cluster analysis.
-     # The current implementation uses the Euclidean distance for analyzing the clusters.
-     #
-     # @example
-     #   analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
-     #   cluster_labels = analyzer.fit_predict(samples)
-     #
-     # *Reference*
-     # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 226--231, 1996.
-     class DBSCAN
-       include Base::BaseEstimator
-       include Base::ClusterAnalyzer
-       include Validation
-
-       # Return the core sample indices.
-       # @return [Numo::Int32] (shape: [n_core_samples])
-       attr_reader :core_sample_ids
-
-       # Return the cluster labels. The negative cluster label indicates that the point is noise.
-       # @return [Numo::Int32] (shape: [n_samples])
-       attr_reader :labels
-
-       # Create a new cluster analyzer with DBSCAN method.
-       #
-       # @param eps [Float] The radius of neighborhood.
-       # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
-       def initialize(eps: 0.5, min_samples: 5)
-         check_params_float(eps: eps)
-         check_params_integer(min_samples: min_samples)
-         @params = {}
-         @params[:eps] = eps
-         @params[:min_samples] = min_samples
-         @core_sample_ids = nil
-         @labels = nil
-       end
-
-       # Analyze clusters with the given training data.
-       #
-       # @overload fit(x) -> DBSCAN
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-       # @return [DBSCAN] The learned cluster analyzer itself.
-       def fit(x, _y = nil)
-         check_sample_array(x)
-         partial_fit(x)
-         self
-       end
-
-       # Analyze clusters and assign samples to clusters.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
-       def fit_predict(x)
-         check_sample_array(x)
-         partial_fit(x)
-         labels
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data.
-       def marshal_dump
-         { params: @params,
-           core_sample_ids: @core_sample_ids,
-           labels: @labels }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @core_sample_ids = obj[:core_sample_ids]
-         @labels = obj[:labels]
-         nil
-       end
-
-       private
-
-       def partial_fit(x)
-         cluster_id = 0
-         n_samples = x.shape[0]
-         @core_sample_ids = []
-         @labels = Numo::Int32.zeros(n_samples) - 2
-         n_samples.times do |q|
-           next if @labels[q] >= -1
-           cluster_id += 1 if expand_cluster(x, q, cluster_id)
-         end
-         @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
-         nil
-       end
-
-       def expand_cluster(x, query_id, cluster_id)
-         target_ids = region_query(x[query_id, true], x)
-         if target_ids.size < @params[:min_samples]
-           @labels[query_id] = -1
-           false
-         else
-           @labels[target_ids] = cluster_id
-           @core_sample_ids.push(target_ids.dup)
-           target_ids.delete(query_id)
-           while (m = target_ids.shift)
-             neighbor_ids = region_query(x[m, true], x)
-             next if neighbor_ids.size < @params[:min_samples]
-             neighbor_ids.each do |n|
-               target_ids.push(n) if @labels[n] < -1
-               @labels[n] = cluster_id if @labels[n] <= -1
-             end
-           end
-           true
-         end
-       end
-
-       def region_query(query, targets)
-         distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
-         distance_arr.lt(@params[:eps]).where.to_a
-       end
-     end
-   end
- end
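
Since the class above defines marshal_dump and marshal_load, a fitted analyzer can be persisted with Ruby's Marshal. A short sketch (samples stands in for a Numo::DFloat of shape [n_samples, n_features]):

require 'svmkit'

# Fit, serialize, and restore a DBSCAN analyzer via Marshal.
analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
cluster_labels = analyzer.fit_predict(samples)

File.binwrite('dbscan.dat', Marshal.dump(analyzer))
restored = Marshal.load(File.binread('dbscan.dat'))
restored.labels       # same cluster labels as the fitted analyzer
restored.params[:eps] # => 0.5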