svmkit 0.7.3 → 0.8.1

Files changed (78)
  1. checksums.yaml +4 -4
  2. data/.gitignore +0 -9
  3. data/.rspec +1 -0
  4. data/.travis.yml +4 -12
  5. data/LICENSE.txt +1 -1
  6. data/README.md +11 -13
  7. data/lib/svmkit.rb +3 -66
  8. data/svmkit.gemspec +12 -7
  9. metadata +16 -81
  10. data/.coveralls.yml +0 -1
  11. data/.rubocop.yml +0 -47
  12. data/.rubocop_todo.yml +0 -58
  13. data/HISTORY.md +0 -168
  14. data/lib/svmkit/base/base_estimator.rb +0 -13
  15. data/lib/svmkit/base/classifier.rb +0 -34
  16. data/lib/svmkit/base/cluster_analyzer.rb +0 -29
  17. data/lib/svmkit/base/evaluator.rb +0 -13
  18. data/lib/svmkit/base/regressor.rb +0 -34
  19. data/lib/svmkit/base/splitter.rb +0 -17
  20. data/lib/svmkit/base/transformer.rb +0 -18
  21. data/lib/svmkit/clustering/dbscan.rb +0 -127
  22. data/lib/svmkit/clustering/k_means.rb +0 -140
  23. data/lib/svmkit/dataset.rb +0 -109
  24. data/lib/svmkit/decomposition/nmf.rb +0 -147
  25. data/lib/svmkit/decomposition/pca.rb +0 -150
  26. data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
  27. data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
  28. data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
  29. data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
  30. data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
  31. data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
  32. data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
  33. data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
  34. data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
  35. data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
  36. data/lib/svmkit/evaluation_measure/precision.rb +0 -51
  37. data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
  38. data/lib/svmkit/evaluation_measure/purity.rb +0 -41
  39. data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
  40. data/lib/svmkit/evaluation_measure/recall.rb +0 -51
  41. data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
  42. data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
  43. data/lib/svmkit/linear_model/lasso.rb +0 -138
  44. data/lib/svmkit/linear_model/linear_regression.rb +0 -112
  45. data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
  46. data/lib/svmkit/linear_model/ridge.rb +0 -112
  47. data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
  48. data/lib/svmkit/linear_model/svc.rb +0 -184
  49. data/lib/svmkit/linear_model/svr.rb +0 -123
  50. data/lib/svmkit/model_selection/cross_validation.rb +0 -121
  51. data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
  52. data/lib/svmkit/model_selection/k_fold.rb +0 -77
  53. data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
  54. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
  55. data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
  56. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
  57. data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
  58. data/lib/svmkit/optimizer/nadam.rb +0 -90
  59. data/lib/svmkit/optimizer/rmsprop.rb +0 -69
  60. data/lib/svmkit/optimizer/sgd.rb +0 -65
  61. data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
  62. data/lib/svmkit/pairwise_metric.rb +0 -91
  63. data/lib/svmkit/pipeline/pipeline.rb +0 -197
  64. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
  65. data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
  66. data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
  67. data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
  68. data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
  69. data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
  70. data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
  71. data/lib/svmkit/probabilistic_output.rb +0 -112
  72. data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
  73. data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
  74. data/lib/svmkit/tree/node.rb +0 -70
  75. data/lib/svmkit/utils.rb +0 -22
  76. data/lib/svmkit/validation.rb +0 -79
  77. data/lib/svmkit/values.rb +0 -13
  78. data/lib/svmkit/version.rb +0 -7
data/.rubocop_todo.yml DELETED
@@ -1,58 +0,0 @@
- # This configuration was generated by
- # `rubocop --auto-gen-config`
- # on 2018-06-10 12:21:53 +0900 using RuboCop version 0.57.1.
- # The point is for the user to remove these configuration records
- # one by one as the offenses are removed from the code base.
- # Note that changes in the inspected code, or installation of new
- # versions of RuboCop, may require this file to be generated again.
-
- # Offense count: 2
- # Cop supports --auto-correct.
- Layout/ClosingHeredocIndentation:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 2
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle.
- # SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
- Layout/IndentHeredoc:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 1
- # Cop supports --auto-correct.
- Layout/LeadingBlankLines:
-   Exclude:
-     - 'svmkit.gemspec'
-
- # Offense count: 1
- # Configuration parameters: CountComments, ExcludedMethods.
- Metrics/BlockLength:
-   Max: 29
-
- # Offense count: 3
- Metrics/CyclomaticComplexity:
-   Max: 12
-
- # Offense count: 3
- Metrics/PerceivedComplexity:
-   Max: 13
-
- # Offense count: 1
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
- # SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
- Style/HashSyntax:
-   Exclude:
-     - 'Rakefile'
-
- # Offense count: 6
- # Cop supports --auto-correct.
- # Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
- # SupportedStyles: single_quotes, double_quotes
- Style/StringLiterals:
-   Exclude:
-     - 'Gemfile'
-     - 'Rakefile'
-     - 'bin/console'
data/HISTORY.md DELETED
@@ -1,168 +0,0 @@
- # 0.7.3
- Add class for grid search performing hyperparameter optimization.
- Add argument validations to Pipeline.
-
- # 0.7.2
- Add class for Pipeline that constructs chain of transformers and estimators.
- Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
-
- # 0.7.1
- Fix to use CSV class in parsing libsvm format file.
- Refactor ensemble estimators.
-
- # 0.7.0
- Add class for AdaBoost classifier.
- Add class for AdaBoost regressor.
-
- # 0.6.3
- Fix bug on setting random seed and max_features parameter of Random Forest estimators.
-
- # 0.6.2
- Refactor decision tree classes for improving performance.
-
- # 0.6.1
- Add abstract class for linear estimator with stochastic gradient descent.
- Refactor linear estimators to use linear estimator abstract class.
- Refactor decision tree classes to avoid unneeded type conversion.
-
- # 0.6.0
- Add class for Principal Component Analysis.
- Add class for Non-negative Matrix Factorization.
-
- # 0.5.2
- Add class for DBSCAN clustering.
-
- # 0.5.1
- Fix bug on class probability calculation of DecisionTreeClassifier.
-
- # 0.5.0
- Add class for K-Means clustering.
- Add class for evaluating purity.
- Add class for evaluating normalized mutual information.
-
- # 0.4.1
- Add class for linear regressor.
- Add class for SGD optimizer.
- Add class for RMSProp optimizer.
- Add class for YellowFin optimizer.
- Fix to be able to select optimizer on estimators of LinearModel and PolynomialModel.
-
- # 0.4.0
- ## Breaking changes
-
- SVMKit introduces an optimizer algorithm that calculates learning rates adaptively
- on each iteration of stochastic gradient descent (SGD).
- While Pegasos SGD runs fast, it sometimes fails to optimize complicated models
- like Factorization Machine.
- To solve this problem, in version 0.3.3, SVMKit introduced optimization with RMSProp on
- FactorizationMachineRegressor, Ridge and Lasso.
- This attempt realized stable optimization of those estimators.
- Following the success of the attempt, the author decided to use modern optimizer algorithms
- with all SGD optimizations in SVMKit.
- Through some preliminary experiments, the author implemented Nadam as the default optimizer.
- SVMKit plans to add other optimizer algorithms sequentially, so that users can select them.
-
- Fix to use Nadam for optimization on SVC, SVR, LogisticRegression, Ridge, Lasso, and Factorization Machine estimators.
- Combine reg_param_weight and reg_param_bias parameters on Factorization Machine estimators into the unified parameter named reg_param_linear.
- Remove init_std parameter on Factorization Machine estimators.
- Remove learning_rate, decay, and momentum parameters on Ridge, Lasso, and FactorizationMachineRegressor.
- Remove normalize parameter on SVC, SVR, and LogisticRegression.
-
- # 0.3.3
- Add class for Ridge regressor.
- Add class for Lasso regressor.
- Fix bug on gradient calculation of FactorizationMachineRegressor.
- Fix some documents.
-
- # 0.3.2
- Add class for Factorization Machine regressor.
- Add class for Decision Tree regressor.
- Add class for Random Forest regressor.
- Fix to support loading and dumping libsvm file with multi-target variables.
- Fix to require DecisionTreeClassifier on RandomForestClassifier.
- Fix some mistakes on document.
-
- # 0.3.1
- Fix bug on decision function calculation of FactorizationMachineClassifier.
- Fix bug on weight updating process of KernelSVC.
-
- # 0.3.0
- Add class for Support Vector Regression.
- Add class for K-Nearest Neighbor Regression.
- Add class for evaluating coefficient of determination.
- Add class for evaluating mean squared error.
- Add class for evaluating mean absolute error.
- Fix to use min method instead of sort and first methods.
- Fix cross validation class to be able to use for regression problem.
- Fix some typos on document.
- Rename spec filename for Factorization Machine classifier.
-
- # 0.2.9
- Add predict_proba method to SVC and KernelSVC.
- Add class for evaluating logarithmic loss.
- Add classes for Label- and One-Hot- encoding.
- Add some validator.
- Fix bug on training data score calculation of cross validation.
- Fix fit method of SVC for performance.
- Fix criterion calculation on Decision Tree for performance.
- Fix data structure of Decision Tree for performance.
-
- # 0.2.8
- Fix bug on gradient calculation of Logistic Regression.
- Fix to change accessor of params of estimators to read only.
- Add parameter validation.
-
- # 0.2.7
- Fix to support multiclass classification in LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier.
-
- # 0.2.6
- Add class for Decision Tree classifier.
- Add class for Random Forest classifier.
- Fix to use frozen string literal.
- Refactor marshal dump method on some classes.
- Introduce Coveralls to confirm test coverage.
-
- # 0.2.5
- Add classes for Naive Bayes classifier.
- Fix decision function method on Logistic Regression class.
- Fix method visibility on RBF kernel approximation class.
-
- # 0.2.4
- Add class for Factorization Machine classifier.
- Add classes for evaluation measures.
- Fix the method for prediction of class probability in Logistic Regression.
-
- # 0.2.3
- Add class for cross validation.
- Add specs for base modules.
- Fix validation of the number of splits when a negative label is given.
-
- # 0.2.2
- Add data splitter classes for K-fold cross validation.
-
- # 0.2.1
- Add class for K-nearest neighbors classifier.
-
- # 0.2.0
- Migrated the linear algebra library to Numo::NArray.
- Add module for loading and saving libsvm format file.
-
- # 0.1.3
- Add class for Kernel Support Vector Machine with Pegasos algorithm.
- Add module for calculating pairwise kernel functions and Euclidean distances.
-
- # 0.1.2
- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
- Rewrite the document with yard notation.
-
- # 0.1.1
- Add class for Logistic Regression with SGD optimization.
- Fix some mistakes on the document.
-
- # 0.1.0
- Add basic classes.
- Add a utility module.
- Add class for RBF kernel approximation.
- Add class for Support Vector Machine with Pegasos algorithm.
- Add class that performs multiclass classification with one-vs.-rest strategy.
- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
data/lib/svmkit/base/base_estimator.rb DELETED
@@ -1,13 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   # This module consists of basic mix-in classes.
-   module Base
-     # Base module for all estimators in SVMKit.
-     module BaseEstimator
-       # Return parameters about an estimator.
-       # @return [Hash]
-       attr_reader :params
-     end
-   end
- end
data/lib/svmkit/base/classifier.rb DELETED
@@ -1,34 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/accuracy'
-
- module SVMKit
-   module Base
-     # Module for all classifiers in SVMKit.
-     module Classifier
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for predicting labels.
-       def predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate the mean accuracy of the given testing data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-       # @return [Float] Mean accuracy
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::Accuracy.new
-         evaluator.score(y, predict(x))
-       end
-     end
-   end
- end
data/lib/svmkit/base/cluster_analyzer.rb DELETED
@@ -1,29 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/purity'
-
- module SVMKit
-   module Base
-     # Module for all clustering algorithms in SVMKit.
-     module ClusterAnalyzer
-       # An abstract method for analyzing clusters and predicting cluster indices.
-       def fit_predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate purity of clustering result.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-       # @return [Float] Purity
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_label_array(y)
-         SVMKit::Validation.check_sample_label_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::Purity.new
-         evaluator.score(y, fit_predict(x))
-       end
-     end
-   end
- end
data/lib/svmkit/base/evaluator.rb DELETED
@@ -1,13 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all evaluation measures in SVMKit.
-     module Evaluator
-       # An abstract method for evaluation of model.
-       def score
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
data/lib/svmkit/base/regressor.rb DELETED
@@ -1,34 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/evaluation_measure/r2_score'
-
- module SVMKit
-   module Base
-     # Module for all regressors in SVMKit.
-     module Regressor
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for predicting labels.
-       def predict
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # Calculate the coefficient of determination for the given testing data.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-       # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
-       # @return [Float] Coefficient of determination
-       def score(x, y)
-         SVMKit::Validation.check_sample_array(x)
-         SVMKit::Validation.check_tvalue_array(y)
-         SVMKit::Validation.check_sample_tvalue_size(x, y)
-         evaluator = SVMKit::EvaluationMeasure::R2Score.new
-         evaluator.score(y, predict(x))
-       end
-     end
-   end
- end
data/lib/svmkit/base/splitter.rb DELETED
@@ -1,17 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all validation methods in SVMKit.
-     module Splitter
-       # Return the number of splits.
-       # @return [Integer]
-       attr_reader :n_splits
-
-       # An abstract method for splitting dataset.
-       def split
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
data/lib/svmkit/base/transformer.rb DELETED
@@ -1,18 +0,0 @@
- # frozen_string_literal: true
-
- module SVMKit
-   module Base
-     # Module for all transformers in SVMKit.
-     module Transformer
-       # An abstract method for fitting a model.
-       def fit
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-
-       # An abstract method for fitting a model and transforming given data.
-       def fit_transform
-         raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-       end
-     end
-   end
- end
data/lib/svmkit/clustering/dbscan.rb DELETED
@@ -1,127 +0,0 @@
- # frozen_string_literal: true
-
- require 'svmkit/validation'
- require 'svmkit/base/base_estimator'
- require 'svmkit/base/cluster_analyzer'
- require 'svmkit/pairwise_metric'
-
- module SVMKit
-   module Clustering
-     # DBSCAN is a class that implements DBSCAN cluster analysis.
-     # The current implementation uses the Euclidean distance for analyzing the clusters.
-     #
-     # @example
-     #   analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
-     #   cluster_labels = analyzer.fit_predict(samples)
-     #
-     # *Reference*
-     # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD '96, pp. 226--231, 1996.
-     class DBSCAN
-       include Base::BaseEstimator
-       include Base::ClusterAnalyzer
-       include Validation
-
-       # Return the core sample indices.
-       # @return [Numo::Int32] (shape: [n_core_samples])
-       attr_reader :core_sample_ids
-
-       # Return the cluster labels. The negative cluster label indicates that the point is noise.
-       # @return [Numo::Int32] (shape: [n_samples])
-       attr_reader :labels
-
-       # Create a new cluster analyzer with DBSCAN method.
-       #
-       # @param eps [Float] The radius of neighborhood.
-       # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
-       def initialize(eps: 0.5, min_samples: 5)
-         check_params_float(eps: eps)
-         check_params_integer(min_samples: min_samples)
-         @params = {}
-         @params[:eps] = eps
-         @params[:min_samples] = min_samples
-         @core_sample_ids = nil
-         @labels = nil
-       end
-
-       # Analyze clusters with given training data.
-       #
-       # @overload fit(x) -> DBSCAN
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-       # @return [DBSCAN] The learned cluster analyzer itself.
-       def fit(x, _y = nil)
-         check_sample_array(x)
-         partial_fit(x)
-         self
-       end
-
-       # Analyze clusters and assign samples to clusters.
-       #
-       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-       # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
-       def fit_predict(x)
-         check_sample_array(x)
-         partial_fit(x)
-         labels
-       end
-
-       # Dump marshal data.
-       # @return [Hash] The marshal data.
-       def marshal_dump
-         { params: @params,
-           core_sample_ids: @core_sample_ids,
-           labels: @labels }
-       end
-
-       # Load marshal data.
-       # @return [nil]
-       def marshal_load(obj)
-         @params = obj[:params]
-         @core_sample_ids = obj[:core_sample_ids]
-         @labels = obj[:labels]
-         nil
-       end
-
-       private
-
-       def partial_fit(x)
-         cluster_id = 0
-         n_samples = x.shape[0]
-         @core_sample_ids = []
-         @labels = Numo::Int32.zeros(n_samples) - 2
-         n_samples.times do |q|
-           next if @labels[q] >= -1
-           cluster_id += 1 if expand_cluster(x, q, cluster_id)
-         end
-         @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
-         nil
-       end
-
-       def expand_cluster(x, query_id, cluster_id)
-         target_ids = region_query(x[query_id, true], x)
-         if target_ids.size < @params[:min_samples]
-           @labels[query_id] = -1
-           false
-         else
-           @labels[target_ids] = cluster_id
-           @core_sample_ids.push(target_ids.dup)
-           target_ids.delete(query_id)
-           while (m = target_ids.shift)
-             neighbor_ids = region_query(x[m, true], x)
-             next if neighbor_ids.size < @params[:min_samples]
-             neighbor_ids.each do |n|
-               target_ids.push(n) if @labels[n] < -1
-               @labels[n] = cluster_id if @labels[n] <= -1
-             end
-           end
-           true
-         end
-       end
-
-       def region_query(query, targets)
-         distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
-         distance_arr.lt(@params[:eps]).where.to_a
-       end
-     end
-   end
- end
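Since this release removes the whole SVMKit::Clustering::DBSCAN class shown above, a short usage sketch of the 0.7.x API may help readers still pinned to the old gem. This is a minimal, illustrative sketch only, assuming svmkit 0.7.3 and numo-narray are installed; the synthetic blob data and the names blob_a, blob_b, and noise are invented for the example.

```ruby
# Minimal sketch of the removed 0.7.x DBSCAN API (synthetic data, not from the gem).
require 'svmkit'
require 'numo/narray'

# Build two dense blobs plus a few scattered noise points.
blob_a  = Numo::DFloat.new(50, 2).rand * 0.5
blob_b  = Numo::DFloat.new(50, 2).rand * 0.5 + 5.0
noise   = Numo::DFloat.new(10, 2).rand * 10.0
samples = Numo::NArray.concatenate([blob_a, blob_b, noise], axis: 0)

# Same constructor arguments as in the class documentation above.
analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
cluster_labels = analyzer.fit_predict(samples)

# Cluster ids start at 0; negative labels mark noise points.
puts "clusters found: #{cluster_labels.max + 1}"
puts "noise points:   #{cluster_labels.lt(0).where.size}"
```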