svmkit 0.5.1 → 0.5.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8f4ee565e18136b7f40832368ef78df514b7390a20929d40efb623d2ba7c0378
4
- data.tar.gz: e05f3ff80b02ee41a7ce4c32cf6bc6cc99f30771ae9f719eb4fea716680da229
3
+ metadata.gz: 917f85878296b940b497f13253e3d3b03047be8f154d554116c2629aaeea55dd
4
+ data.tar.gz: 16308e4638b15a55843f15b4e0d97886f27aae0cc236c59c590a8f9fe7f0e5c6
5
5
  SHA512:
6
- metadata.gz: 48cc0e18b0aa8a5ace9ceb07744249b799f61ffc04b177bbbb229529754c6e138f13dcd9f5ff04c2b61a380abcef01a6d0f4c175a86b7829c00aad9b91181521
7
- data.tar.gz: 2b84b8983d392015dc30cb69851e4df11c4d1b30e6e521763179e648e239482c0fb508ccacf46e03cbbe8f2765b123df4a595a3f7f2a2384b0d1c0f085cd78fd
6
+ metadata.gz: d390d3ef0d7b06676e6d3c34479939b4a99ee01472816eacbe49fd3f40224ef5984620dfe6d335fb5b15e7213d3b0d17ba9441766e7cdd08c8bad9bff669db8d
7
+ data.tar.gz: ab2239c0d1297e18e31940e763875ac24668d8c4c3f30355f06bc5ed305c247ff0328e1d584c5ab70ce77d4d2f946dcc5f72f1eb4c3a25d9b0dcd38e1d246182
data/HISTORY.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.5.2
2
+ - Add class for DBSCAN clustering.
3
+
1
4
  # 0.5.1
2
5
  - Fix bug on class probability calculation of DecisionTreeClassifier.
3
6
 
data/README.md CHANGED
@@ -10,7 +10,7 @@ SVMKit provides machine learning algorithms with interfaces similar to Scikit-Le
10
10
  SVMKit currently supports Linear / Kernel Support Vector Machine,
11
11
  Logistic Regression, Linear Regression, Ridge, Lasso, Factorization Machine,
12
12
  Naive Bayes, Decision Tree, Random Forest, K-nearest neighbor classifier,
13
- K-Means and cross-validation.
13
+ K-Means, DBSCAN and cross-validation.
14
14
 
15
15
  ## Installation
16
16
 
@@ -38,6 +38,7 @@ require 'svmkit/tree/decision_tree_regressor'
38
38
  require 'svmkit/ensemble/random_forest_classifier'
39
39
  require 'svmkit/ensemble/random_forest_regressor'
40
40
  require 'svmkit/clustering/k_means'
41
+ require 'svmkit/clustering/dbscan'
41
42
  require 'svmkit/preprocessing/l2_normalizer'
42
43
  require 'svmkit/preprocessing/min_max_scaler'
43
44
  require 'svmkit/preprocessing/standard_scaler'
@@ -0,0 +1,127 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'svmkit/validation'
4
+ require 'svmkit/base/base_estimator'
5
+ require 'svmkit/base/cluster_analyzer'
6
+ require 'svmkit/pairwise_metric'
7
+
8
+ module SVMKit
9
+ module Clustering
10
+ # DBSCAN is a class that implements DBSCAN cluster analysis.
11
+ # The current implementation uses the Euclidean distance for analyzing the clusters.
12
+ #
13
+ # @example
14
+ # analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
15
+ # cluster_labels = analyzer.fit_predict(samples)
16
+ #
17
+ # *Reference*
18
+ # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD'96, pp. 226--231, 1996.
19
+ class DBSCAN
20
+ include Base::BaseEstimator
21
+ include Base::ClusterAnalyzer
22
+ include Validation
23
+
24
+ # Return the core sample indices.
25
+ # @return [Numo::Int32] (shape: [n_core_samples])
26
+ attr_reader :core_sample_ids
27
+
28
+ # Return the cluster labels. A negative cluster label indicates that the point is noise.
29
+ # @return [Numo::Int32] (shape: [n_samples])
30
+ attr_reader :labels
31
+
32
+ # Create a new cluster analyzer with DBSCAN method.
33
+ #
34
+ # @param eps [Float] The radius of neighborhood.
35
+ # @param min_samples [Integer] The minimum number of neighbor samples (within eps) required for a point to be regarded as a core point.
36
+ def initialize(eps: 0.5, min_samples: 5)
37
+ check_params_float(eps: eps)
38
+ check_params_integer(min_samples: min_samples)
39
+ @params = {}
40
+ @params[:eps] = eps
41
+ @params[:min_samples] = min_samples
42
+ @core_sample_ids = nil
43
+ @labels = nil
44
+ end
45
+
46
+ # Analyze clusters with the given training data.
47
+ #
48
+ # @overload fit(x) -> DBSCAN
49
+ #
50
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
51
+ # @return [DBSCAN] The learned cluster analyzer itself.
52
+ def fit(x, _y = nil)
53
+ check_sample_array(x)
54
+ partial_fit(x)
55
+ self
56
+ end
57
+
58
+ # Analyze clusters and assign samples to clusters.
59
+ #
60
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
61
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
62
+ def fit_predict(x)
63
+ check_sample_array(x)
64
+ partial_fit(x)
65
+ labels
66
+ end
67
+
68
+ # Dump marshal data.
69
+ # @return [Hash] The marshal data.
70
+ def marshal_dump
71
+ { params: @params,
72
+ core_sample_ids: @core_sample_ids,
73
+ labels: @labels }
74
+ end
75
+
76
+ # Load marshal data.
77
+ # @return [nil]
78
+ def marshal_load(obj)
79
+ @params = obj[:params]
80
+ @core_sample_ids = obj[:core_sample_ids]
81
+ @labels = obj[:labels]
82
+ nil
83
+ end
84
+
85
+ private
86
+
87
+ def partial_fit(x)
88
+ cluster_id = 0
89
+ n_samples = x.shape[0]
90
+ @core_sample_ids = []
91
+ @labels = Numo::Int32.zeros(n_samples) - 2
92
+ n_samples.times do |q|
93
+ next if @labels[q] >= -1
94
+ cluster_id += 1 if expand_cluster(x, q, cluster_id)
95
+ end
96
+ @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
97
+ nil
98
+ end
99
+
100
+ def expand_cluster(x, query_id, cluster_id)
101
+ target_ids = region_query(x[query_id, true], x)
102
+ if target_ids.size < @params[:min_samples]
103
+ @labels[query_id] = -1
104
+ false
105
+ else
106
+ @labels[target_ids] = cluster_id
107
+ @core_sample_ids.push(target_ids.dup)
108
+ target_ids.delete(query_id)
109
+ while (m = target_ids.shift)
110
+ neighbor_ids = region_query(x[m, true], x)
111
+ next if neighbor_ids.size < @params[:min_samples]
112
+ neighbor_ids.each do |n|
113
+ target_ids.push(n) if @labels[n] < -1
114
+ @labels[n] = cluster_id if @labels[n] <= -1
115
+ end
116
+ end
117
+ true
118
+ end
119
+ end
120
+
121
+ def region_query(query, targets)
122
+ distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
123
+ distance_arr.lt(@params[:eps]).where.to_a
124
+ end
125
+ end
126
+ end
127
+ end
@@ -9,10 +9,11 @@ module SVMKit
9
9
  # This module consists of classes that implement cluster analysis methods.
10
10
  module Clustering
11
11
  # KMeans is a class that implements K-Means cluster analysis.
12
+ # The current implementation uses the Euclidean distance for analyzing the clusters.
12
13
  #
13
14
  # @example
14
15
  # analyzer = SVMKit::Clustering::KMeans.new(n_clusters: 10, max_iter: 50)
15
- # cluster_ids = analyzer.fit_predict(samples)
16
+ # cluster_labels = analyzer.fit_predict(samples)
16
17
  #
17
18
  # *Reference*
18
19
  # - D. Arthur and S. Vassilvitskii, "k-means++: the advantages of careful seeding," Proc. SODA'07, pp. 1027--1035, 2007.
@@ -38,6 +39,7 @@ module SVMKit
38
39
  # @param random_seed [Integer] The seed value used to initialize the random generator.
39
40
  def initialize(n_clusters: 8, init: 'k-means++', max_iter: 50, tol: 1.0e-4, random_seed: nil)
40
41
  check_params_integer(n_clusters: n_clusters, max_iter: max_iter)
42
+ check_params_float(tol: tol)
41
43
  check_params_string(init: init)
42
44
  check_params_type_or_nil(Integer, random_seed: random_seed)
43
45
  check_params_positive(n_clusters: n_clusters, max_iter: max_iter)
@@ -62,10 +64,10 @@ module SVMKit
62
64
  check_sample_array(x)
63
65
  init_cluster_centers(x)
64
66
  @params[:max_iter].times do |_t|
65
- cluster_ids = assign_cluster(x)
67
+ cluster_labels = assign_cluster(x)
66
68
  old_centers = @cluster_centers.dup
67
69
  @params[:n_clusters].times do |n|
68
- assigned_bits = cluster_ids.eq(n)
70
+ assigned_bits = cluster_labels.eq(n)
69
71
  @cluster_centers[n, true] = x[assigned_bits.where, true].mean(axis: 0) if assigned_bits.count > 0
70
72
  end
71
73
  error = Numo::NMath.sqrt(((old_centers - @cluster_centers)**2).sum(axis: 1)).mean
@@ -74,10 +76,10 @@ module SVMKit
74
76
  self
75
77
  end
76
78
 
77
- # Predict cluster indices for samples.
79
+ # Predict cluster labels for samples.
78
80
  #
79
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster index.
80
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample.
81
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the cluster label.
82
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
81
83
  def predict(x)
82
84
  check_sample_array(x)
83
85
  assign_cluster(x)
@@ -86,7 +88,7 @@ module SVMKit
86
88
  # Analyze clusters and assign samples to clusters.
87
89
  #
88
90
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
89
- # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster index per sample.
91
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
90
92
  def fit_predict(x)
91
93
  check_sample_array(x)
92
94
  fit(x)
@@ -3,5 +3,5 @@
3
3
  # SVMKit is a machine learning library in Ruby.
4
4
  module SVMKit
5
5
  # @!visibility private
6
- VERSION = '0.5.1'.freeze
6
+ VERSION = '0.5.2'.freeze
7
7
  end
@@ -18,7 +18,7 @@ SVMKit provides machine learning algorithms with interfaces similar to Scikit-Le
18
18
  SVMKit currently supports Linear / Kernel Support Vector Machine,
19
19
  Logistic Regression, Linear Regression, Ridge, Lasso, Factorization Machine,
20
20
  Naive Bayes, Decision Tree, Random Forest, K-nearest neighbor algorithm,
21
- K-Means and cross-validation.
21
+ K-Means, DBSCAN and cross-validation.
22
22
  MSG
23
23
  spec.homepage = 'https://github.com/yoshoku/svmkit'
24
24
  spec.license = 'BSD-2-Clause'
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.1
4
+ version: 0.5.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-06-16 00:00:00.000000000 Z
11
+ date: 2018-06-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -86,7 +86,7 @@ description: |
86
86
  SVMKit currently supports Linear / Kernel Support Vector Machine,
87
87
  Logistic Regression, Linear Regression, Ridge, Lasso, Factorization Machine,
88
88
  Naive Bayes, Decision Tree, Random Forest, K-nearest neighbor algorithm,
89
- K-Means and cross-validation.
89
+ K-Means, DBSCAN and cross-validation.
90
90
  email:
91
91
  - yoshoku@outlook.com
92
92
  executables: []
@@ -115,6 +115,7 @@ files:
115
115
  - lib/svmkit/base/regressor.rb
116
116
  - lib/svmkit/base/splitter.rb
117
117
  - lib/svmkit/base/transformer.rb
118
+ - lib/svmkit/clustering/dbscan.rb
118
119
  - lib/svmkit/clustering/k_means.rb
119
120
  - lib/svmkit/dataset.rb
120
121
  - lib/svmkit/ensemble/random_forest_classifier.rb