cross_validation 0.0.1 → 0.0.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: e36ebf38a97ebf8665474186f2c1ae3c82a3855d
4
- data.tar.gz: 770c131f2d8e3cf359e7e26450cb7a24c704239c
3
+ metadata.gz: 5f5fee4f5040fd1ad3562518031b850f7cf8c4b0
4
+ data.tar.gz: 7076cc8ff37d8967c84f04a493a5dff8729e0202
5
5
  SHA512:
6
- metadata.gz: 34ce7b4484db03a2d09aeb2b1f40c586a84a63b6ba22d355220ee6cba8bf588f3c8a2ac867a82524d43aa9fba06863f4f44ca09af8a17cb7f56409d9105dce7f
7
- data.tar.gz: d55546305d845c1f2c977f825f2d69a77b32369c3afc13b5262dd170b523f4efffa6f7fa9f6d9f7d9215b6b008156ff1ae2cb13993625268bc8d3e1f0b4b2f4c
6
+ metadata.gz: f27ea09189c51f97db89aaee88dee9cca8adefbc503fe596f91297ed9acc6273be0c3414dfd6b1004b9e32ff2cb98cc22edafaa3484ddf062e965a609bdff39b
7
+ data.tar.gz: 04e0bf2d0863102dadddcb053d053c8d4dd5bbb0098acc58b3205a09438d56c5c3077aa3529302017f709f1bab8cf7a7440b690d047b6bd1cb36277daf8eb810
data/README.md CHANGED
@@ -4,7 +4,9 @@
4
4
  [![Code Climate](https://codeclimate.com/github/jmdeldin/cross_validation.png)](https://codeclimate.com/github/jmdeldin/cross_validation)
5
5
 
6
6
  This gem provides a k-fold cross-validation routine and confusion matrix
7
- for evaluating machine learning classifiers.
7
+ for evaluating machine learning classifiers. See [below](#usage) for
8
+ usage or jump to the
9
+ [documentation](http://rubydoc.info/github/jmdeldin/cross_validation/frames).
8
10
 
9
11
  ## Installation
10
12
 
@@ -22,19 +24,89 @@ Or install it yourself as:
22
24
 
23
25
  ## Usage
24
26
 
25
- Cross-validation:
26
-
27
- Confusion-matrix:
28
-
29
-
30
- ## Contributing
31
-
32
- 1. Fork it
33
- 2. Create your feature branch (`git checkout -b my-new-feature`)
34
- 3. Commit your changes (`git commit -am 'Add some feature'`)
35
- 4. Push to the branch (`git push origin my-new-feature`)
36
- 5. Create new Pull Request
37
-
38
- ## Questions
39
-
40
- Send me an email, `dev@jmdeldin.com`
27
+ To cross-validate your classifier, you need to configure a run as
28
+ follows:
29
+
30
+ ```ruby
31
+ require 'cross_validation'
32
+
33
+ runner = CrossValidation::Runner.create do |r|
34
+ r.documents = my_array_of_documents
35
+ r.folds = 10
36
+ # or if you'd rather test on 10%
37
+ # r.percentage = 0.1
38
+ r.classifier = lambda { SpamClassifier.new }
39
+ r.fetch_sample_class = lambda { |sample| sample.klass }
40
+ r.fetch_sample_value = lambda { |sample| sample.value }
41
+ r.matrix = CrossValidation::ConfusionMatrix.new(method(:keys_for))
42
+ r.training = lambda { |classifier, doc|
43
+ classifier.train doc.klass, doc.value
44
+ }
45
+ r.classifying = lambda { |classifier, doc|
46
+ classifier.classify doc
47
+ }
48
+ end
49
+ ```
50
+
51
+ With the run configured, just invoke `#run` to return a confusion matrix:
52
+
53
+ ```ruby
54
+ mat = runner.run
55
+ ```
56
+
57
+ With a confusion matrix in hand, you can compute many statistics about
58
+ your classifier:
59
+
60
+ - `mat.accuracy`
61
+ - `mat.f1`
62
+ - `mat.fscore(beta)`
63
+ - `mat.precision`
64
+ - `mat.recall`
65
+
66
+ Please see the
67
+ [respective documentation](http://rubydoc.info/github/jmdeldin/cross_validation/CrossValidation/ConfusionMatrix)
68
+ for each method for more details.
69
+
70
+ ### Defining `keys_for`
71
+
72
+ The ConfusionMatrix class requires a `keys_for` `Proc` that returns a
73
+ symbol. In this method, you specify what constitutes a true positive
74
+ (`:tp`), true negative (`:tn`), false positive (`:fp`), and false
75
+ negative (`:fn`). For example, in spam classification, you can construct
76
+ the following table to write the keys_for method:
77
+
78
+                 actual
79
+           +---------------------------------
80
+  expected | correct        | not correct
81
+ ----------+----------------+----------------
82
+  spam     | true positive  | false positive
83
+  ham      | true negative  | false negative
84
+
85
+ You can then implement this table with nested hashes or just a few
86
+ conditionals:
87
+
88
+ ```ruby
89
+ def keys_for(expected, actual)
90
+ if expected == :spam
91
+ actual == :spam ? :tp : :fp
92
+ elsif expected == :ham
93
+ actual == :ham ? :tn : :fn
94
+ end
95
+ end
96
+ ```
97
+
98
+ Once you have your `keys_for` method implemented, pass it into the
99
+ ConfusionMatrix with `method(:keys_for)`, or if it's a class-method,
100
+ `MyClass.method(:keys_for)`. (You can also implement the method as a
101
+ lambda.)
102
+
103
+ ## Roadmap
104
+
105
+ For v1.0:
106
+
107
+ - Implement configurable, parallel cross-validation
108
+ - Include more complete examples
109
+
110
+ ## Author
111
+
112
+ Jon-Michael Deldin, `dev@jmdeldin.com`
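The README text in this hunk says `keys_for` can also be built from nested hashes or written as a lambda, but only shows the conditional form. A minimal sketch of the hash-backed lambda, assuming the same `:spam`/`:ham` labels and the same outcome mapping as the conditional version; the constant names `KEYS` and `KEYS_FOR` are hypothetical and not part of the gem:

```ruby
# Outcome table keyed by expected class, then by actual class.
# The mapping mirrors the conditional keys_for shown in the README.
KEYS = {
  spam: { spam: :tp, ham: :fp },
  ham:  { ham: :tn, spam: :fn }
}.freeze

# Unlike the conditional version, Hash#fetch raises KeyError when a
# class is missing from the table instead of falling through.
KEYS_FOR = lambda { |expected, actual| KEYS.fetch(expected).fetch(actual) }

KEYS_FOR.call(:spam, :spam) # => :tp
KEYS_FOR.call(:ham, :spam)  # => :fn
```

Because a lambda already responds to `call`, it can be passed to `ConfusionMatrix.new` as-is, the same way the `bad_keys_for` lambda is passed in the confusion-matrix tests further down.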
data/lib/cross_validation.rb CHANGED
@@ -1,5 +1,9 @@
1
1
  $LOAD_PATH.unshift File.dirname(__FILE__)
2
2
 
3
3
  module CrossValidation
4
- VERSION = '0.0.1'
4
+ VERSION = '0.0.2'
5
+ end
6
+
7
+ %w(confusion_matrix runner).each do |fn|
8
+ require File.join('cross_validation', fn)
5
9
  end
data/lib/cross_validation/confusion_matrix.rb CHANGED
@@ -34,7 +34,7 @@ module CrossValidation
34
34
  # @param [Object] truth The known, expected value
35
35
  # @return [self]
36
36
  def store(actual, truth)
37
- key = @keys_for.call(actual, truth)
37
+ key = @keys_for.call(truth, actual)
38
38
 
39
39
  if @values.key?(key)
40
40
  @values[key] += 1
@@ -83,6 +83,8 @@ module CrossValidation
83
83
  end
84
84
 
85
85
  # Returns the classifier's error
86
+ #
87
+ # @return [Float]
86
88
  def error
87
89
  1.0 - accuracy()
88
90
  end
data/lib/cross_validation/partitioner.rb ADDED
@@ -0,0 +1,34 @@
1
+ module CrossValidation
2
+ # Provides helper methods for data partitioning.
3
+ #
4
+ module Partitioner
5
+
6
+ # Splits the array into +k+-sized subsets.
7
+ #
8
+ # For example, calling this method for the array +%w(foo bar baz qux)+
9
+ # with +k=2+ results in an array of arrays: +[[foo, bar], [baz, qux]]+.
10
+ #
11
+ # @param [Array] ary documents to work with
12
+ # @param [Fixnum] k size of each subset
13
+ # @return [Array] array of arrays
14
+ # @raise [ArgumentError] if the length of the documents array is not
15
+ # evenly divisible by k
16
+ def self.subset(ary, k)
17
+ if ary.length % k > 0
18
+ fail ArgumentError, "Can't create equal subsets when k=#{k}"
19
+ end
20
+
21
+ ary.each_slice(k).to_a
22
+ end
23
+
24
+ # Returns a flattened copy of the original array without an element at
25
+ # index +i+.
26
+ #
27
+ # @param [Array] ary subsets to work with (e.g., array of arrays)
28
+ # @param [Fixnum] i index to remove
29
+ # @return [Array]
30
+ def self.exclude_index(ary, i)
31
+ ary.rotate(i).drop(1).flatten
32
+ end
33
+ end
34
+ end
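The doc comment for `exclude_index` describes its input as an array of arrays, while the accompanying test calls it on a flat array. A small sketch of how the two cases differ; the module body is copied from the hunk above so the snippet runs without the gem, and the sample data is made up:

```ruby
# Copy of the two helpers from the hunk above.
module CrossValidation
  module Partitioner
    def self.subset(ary, k)
      fail ArgumentError, "Can't create equal subsets when k=#{k}" if ary.length % k > 0
      ary.each_slice(k).to_a
    end

    def self.exclude_index(ary, i)
      ary.rotate(i).drop(1).flatten
    end
  end
end

docs  = %w(a b c d e f)
folds = CrossValidation::Partitioner.subset(docs, 2)
# => [["a", "b"], ["c", "d"], ["e", "f"]]

# Nested input: the whole fold at index 1 is held out.
CrossValidation::Partitioner.exclude_index(folds, 1)
# => ["e", "f", "a", "b"]

# Flat input: only the single element at index 1 is dropped.
CrossValidation::Partitioner.exclude_index(docs, 1)
# => ["c", "d", "e", "f", "a"]
```

On a flat array only one element is removed, so a caller that wants to hold out an entire fold for testing would likely need to pass the partitioned array rather than the raw documents.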
data/lib/cross_validation/runner.rb CHANGED
@@ -1,4 +1,6 @@
1
1
  require_relative '../cross_validation'
2
+ require_relative 'partitioner'
3
+ require_relative 'sample'
2
4
 
3
5
  module CrossValidation
4
6
  class Runner
@@ -43,6 +45,17 @@ module CrossValidation
43
45
  # document and should return the document's class.
44
46
  attr_accessor :fetch_sample_class
45
47
 
48
+ # @return [Array] Array of which attributes are empty
49
+ attr_reader :errors
50
+
51
+ def initialize
52
+ @fetch_sample_value = lambda { |sample| sample.value }
53
+ @fetch_sample_class = lambda { |sample| sample.klass }
54
+
55
+ @critical_keys = [:documents, :classifier, :matrix, :training,
56
+ :classifying, :fetch_sample_value, :fetch_sample_class]
57
+ end
58
+
46
59
  # Returns the number of folds to partition the documents into.
47
60
  #
48
61
  # @return [Fixnum]
@@ -50,6 +63,24 @@ module CrossValidation
50
63
  @k ||= percentage ? (documents.size * percentage) : folds
51
64
  end
52
65
 
66
+ # Checks if all of the required run parameters are set.
67
+ #
68
+ # @return [Boolean]
69
+ def valid?
70
+ @errors = []
71
+ @critical_keys.each do |k|
72
+ any_error = public_send(k).nil?
73
+ @errors << k if any_error
74
+ end
75
+
76
+ @errors.size == 0
77
+ end
78
+
79
+ # @see #valid?
80
+ def invalid?
81
+ !valid?
82
+ end
83
+
53
84
  # Performs k-fold cross-validation and returns a confusion matrix.
54
85
  #
55
86
  # The algorithm is as follows (Mitchell, 1997, p147):
@@ -61,24 +92,23 @@ module CrossValidation
61
92
  # classify(partitions[i])
62
93
  # output confusion matrix
63
94
  #
95
+ # @raise [ArgumentError] if the runner is missing required attributes
96
+ # @return [ConfusionMatrix]
64
97
  def run
65
- partitions = documents.each_slice(k).to_a
98
+ fail_if_invalid
99
+
100
+ partitions = Partitioner.subset(documents, k)
66
101
 
67
102
  results = partitions.map.with_index do |part, i|
68
- # Array#rotate puts the element i first, so all we have to do is rotate
69
- # then remove that element to get the training set. Array#drop does not
70
- # mutate the original array either. Array#flatten is needed to coalesce
71
- # our list of lists into one list again.
72
- training_samples = partitions.rotate(i).drop(1).flatten
103
+ training_samples = Partitioner.exclude_index(documents, i)
73
104
 
74
105
  classifier_instance = classifier.call()
75
106
 
76
- # train it
77
- training_samples.each { |doc| training.call(classifier_instance, doc) }
107
+ train(classifier_instance, training_samples)
78
108
 
79
109
  # fetch confusion keys
80
110
  part.each do |x|
81
- prediction = classifying.call(classifier_instance, fetch_sample_value.call(x))
111
+ prediction = classify(classifier_instance, x)
82
112
  matrix.store(prediction, fetch_sample_class.call(x))
83
113
  end
84
114
  end
@@ -91,5 +121,23 @@ module CrossValidation
91
121
  def self.create
92
122
  new.tap { |r| yield(r) }
93
123
  end
124
+
125
+ private
126
+
127
+ def fail_if_invalid
128
+ return nil if valid?
129
+ msg = "The following attribute(s) must be specified: #{errors.join(', ')}"
130
+ fail ArgumentError, msg
131
+ end
132
+
133
+ def train(classifier_instance, samples)
134
+ samples.each do |doc|
135
+ training.call(classifier_instance, doc)
136
+ end
137
+ end
138
+
139
+ def classify(classifier_instance, sample)
140
+ classifying.call(classifier_instance, fetch_sample_value.call(sample))
141
+ end
94
142
  end
95
143
  end
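The validation added to the runner follows a simple pattern: list the required attributes, collect the ones that are still `nil`, and refuse to run if any remain. A stripped-down sketch of that pattern, assuming only two required attributes; `MiniRunner` and `REQUIRED` are hypothetical names used for illustration and do not exist in the gem:

```ruby
# Hypothetical stand-in for the runner; only the validation logic
# from the hunks above is reproduced here.
class MiniRunner
  attr_accessor :documents, :classifier
  attr_reader :errors

  REQUIRED = [:documents, :classifier].freeze

  # Collects the names of required attributes that are still nil.
  def valid?
    @errors = REQUIRED.select { |key| public_send(key).nil? }
    @errors.empty?
  end

  def invalid?
    !valid?
  end

  def run
    if invalid?
      fail ArgumentError,
           "The following attribute(s) must be specified: #{errors.join(', ')}"
    end
    :ok # the real runner would perform the cross-validation here
  end
end

runner = MiniRunner.new
runner.documents = %w(foo bar)
runner.valid? # => false
runner.errors # => [:classifier]
```

Note that `errors` is only populated once `valid?` (or `invalid?`) has been called, which also holds for the runner in the diff.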
data/lib/cross_validation/sample.rb ADDED
@@ -0,0 +1,15 @@
1
+ module CrossValidation
2
+ # Represents a datum and its class (e.g., "spam").
3
+ #
4
+ # This is an optional data structure that simplifies defining training
5
+ # methods in cross-validation runs.
6
+ Sample = Struct.new(:klass, :value)
7
+
8
+ # Converts an array of +[class, value]+ into a `Sample` object.
9
+ #
10
+ # @param [Array] tuple
11
+ # @return [Sample]
12
+ def self.Sample(tuple)
13
+ Sample.new(tuple.fetch(0), tuple.fetch(1))
14
+ end
15
+ end
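The `Sample(tuple)` conversion method makes it easy to turn raw `[class, value]` pairs into documents for a run. A short usage sketch; the struct and conversion method are copied from the hunk above so the snippet runs without the gem, and the `rows` data is made up for illustration:

```ruby
# Copy of the definitions from the hunk above.
module CrossValidation
  Sample = Struct.new(:klass, :value)

  def self.Sample(tuple)
    Sample.new(tuple.fetch(0), tuple.fetch(1))
  end
end

rows = [
  [:spam, 'Buy some viagra!'],
  [:ham,  'Would you like some penicillin!']
]

# Array#fetch raises IndexError if a row has fewer than two elements.
samples = rows.map { |row| CrossValidation::Sample(row) }

samples.first.klass # => :spam
samples.last.value  # => "Would you like some penicillin!"
```

Objects built this way respond to `klass` and `value`, which is what the runner's new default `fetch_sample_class` and `fetch_sample_value` lambdas call.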
data/test/support/spam_classifier.rb ADDED
@@ -0,0 +1,24 @@
1
+ # A toy classifier. As long as you can tell the CrossValidation gem how to
2
+ # invoke your training and classifying methods, then you can do whatever you
3
+ # want in your classifier.
4
+ class SpamClassifier
5
+ def train(klass, document)
6
+ # don't bother, we're that good (in reality, you should probably do some
7
+ # work here)
8
+ end
9
+
10
+ def classify(document)
11
+ document =~ /viagra/ ? :spam : :ham
12
+ end
13
+
14
+ # Dummy method for use in testing confusion matrices. Used to determine
15
+ # whether a class is a true positive|negative or a false positive|negative.
16
+ # This is used when configuring a confusion matrix.
17
+ def self.keys_for(expected, actual)
18
+ if expected == :spam
19
+ actual == :spam ? :tp : :fp
20
+ elsif expected == :ham
21
+ actual == :ham ? :tn : :fn
22
+ end
23
+ end
24
+ end
data/test/test_confusion_matrix.rb CHANGED
@@ -1,4 +1,5 @@
1
1
  require_relative 'test_helper'
2
+ require_relative 'support/spam_classifier'
2
3
  require_relative '../lib/cross_validation/confusion_matrix'
3
4
 
4
5
  class TestConfusionMatrix < MiniTest::Unit::TestCase
@@ -7,7 +8,7 @@ class TestConfusionMatrix < MiniTest::Unit::TestCase
7
8
  end
8
9
 
9
10
  def setup
10
- @mat = CrossValidation::ConfusionMatrix.new(method(:keys_for))
11
+ @mat = CrossValidation::ConfusionMatrix.new(SpamClassifier.method(:keys_for))
11
12
  end
12
13
 
13
14
  def test_true_positives
@@ -31,7 +32,7 @@ class TestConfusionMatrix < MiniTest::Unit::TestCase
31
32
  end
32
33
 
33
34
  def test_store_raises_index_error_on_bad_key
34
- bad_keys_for = ->(actual, expected) { :bad }
35
+ bad_keys_for = ->(expected, actual) { :bad }
35
36
  mat = CrossValidation::ConfusionMatrix.new(bad_keys_for)
36
37
  assert_raises IndexError do
37
38
  mat.store(:ham, :spam)
data/test/test_helper.rb CHANGED
@@ -1,10 +1 @@
1
1
  require 'minitest/autorun'
2
-
3
- # Dummy method for use in testing confusion matrices.
4
- def keys_for(actual, expected)
5
- if actual == :spam
6
- expected == :spam ? :tp : :fn
7
- elsif actual == :ham
8
- expected == :ham ? :tn : :fp
9
- end
10
- end
data/test/test_partitioner.rb ADDED
@@ -0,0 +1,28 @@
1
+ require_relative 'test_helper'
2
+ require_relative '../lib/cross_validation/partitioner'
3
+
4
+ class TestPartitioner < MiniTest::Unit::TestCase
5
+ def setup
6
+ @docs = %w(foo bar baz qux)
7
+ @p = CrossValidation::Partitioner
8
+ end
9
+
10
+ def test_create_equal_subsets_returns_equal_subsets
11
+ subsets = @p.subset(@docs, 2)
12
+
13
+ assert_equal %w(foo bar), subsets.first
14
+ assert_equal %w(baz qux), subsets.last
15
+ end
16
+
17
+ def test_create_equal_subsets_prevents_unequal_subsets
18
+ e = assert_raises ArgumentError do
19
+ @p.subset(@docs, 3)
20
+ end
21
+ assert_equal "Can't create equal subsets when k=3", e.message
22
+ end
23
+
24
+ def test_exclude_by_index
25
+ samples = @p.exclude_index(@docs, 1)
26
+ assert_equal %w(baz qux foo), samples
27
+ end
28
+ end
data/test/test_runner.rb CHANGED
@@ -1,24 +1,9 @@
1
1
  require_relative 'test_helper'
2
+ require_relative 'support/spam_classifier'
2
3
  require_relative '../lib/cross_validation/confusion_matrix'
4
+ require_relative '../lib/cross_validation/sample'
3
5
  require_relative '../lib/cross_validation/runner'
4
6
 
5
- # A stupid classifier
6
- class SpamClassifier
7
- def train(klass, document)
8
- # don't bother, we're that good (in reality, you should probably do some
9
- # work here)
10
- end
11
-
12
- def classify(document)
13
- document =~ /viagra/ ? :spam : :ham
14
- end
15
- end
16
-
17
- # We just need to associate a class with a value. Feel free to use whatever
18
- # data structure you like -- this is only used in user-defined training and
19
- # classifying closures.
20
- Sample = Struct.new(:klass, :value)
21
-
22
7
  # Asserts the DSL's getter and setters work.
23
8
  def check_dsl(attribute, value)
24
9
  runner = CrossValidation::Runner.create { |r|
@@ -28,13 +13,17 @@ def check_dsl(attribute, value)
28
13
  define_method("test_#{attribute}_getter") {
29
14
  assert_equal :value, runner.public_send(attribute)
30
15
  }
16
+
17
+ define_method("test_runner_is_invalid_with_only_#{attribute}_set") {
18
+ assert runner.invalid?
19
+ }
31
20
  end
32
21
 
33
22
  class TestRunner < MiniTest::Unit::TestCase
34
23
  def setup
35
24
  tpl = ['Buy some...', 'Would you like some...']
36
- @spam = tpl.map { |pfx| Sample.new(:spam, pfx + 'viagra!') }
37
- @ham = tpl.map { |pfx| Sample.new(:ham, pfx + 'penicillin!') }
25
+ @spam = tpl.map { |pfx| CrossValidation::Sample.new(:spam, pfx + 'viagra!') }
26
+ @ham = tpl.map { |pfx| CrossValidation::Sample.new(:ham, pfx + 'penicillin!') }
38
27
  @corpus = @spam + @ham
39
28
  @corpus *= 25 # 100 is easier to deal with
40
29
  end
@@ -44,9 +33,7 @@ class TestRunner < MiniTest::Unit::TestCase
44
33
  r.documents = @corpus
45
34
  r.folds = 10
46
35
  r.classifier = lambda { SpamClassifier.new }
47
- r.fetch_sample_class = lambda { |sample| sample.klass }
48
- r.fetch_sample_value = lambda { |sample| sample.value }
49
- r.matrix = CrossValidation::ConfusionMatrix.new(method(:keys_for))
36
+ r.matrix = CrossValidation::ConfusionMatrix.new(SpamClassifier.method(:keys_for))
50
37
  r.training = lambda { |classifier, doc|
51
38
  classifier.train doc.klass, doc.value
52
39
  }
@@ -83,4 +70,12 @@ class TestRunner < MiniTest::Unit::TestCase
83
70
  ].each do |attribute|
84
71
  check_dsl(attribute, :foo)
85
72
  end
73
+
74
+ def test_invalid_runner_raises_error
75
+ runner = CrossValidation::Runner.create {}
76
+ exception = assert_raises ArgumentError do
77
+ runner.run
78
+ end
79
+ assert_match(/must be specified/, exception.message)
80
+ end
86
81
  end
data/test/test_sample.rb ADDED
@@ -0,0 +1,29 @@
1
+ require_relative 'test_helper'
2
+ require_relative '../lib/cross_validation/sample'
3
+
4
+ class TestSample < MiniTest::Unit::TestCase
5
+ def setup
6
+ @sample = CrossValidation::Sample.new(:spam, :spammy_msg)
7
+ end
8
+
9
+ def test_klass
10
+ assert_equal :spam, @sample.klass
11
+ end
12
+
13
+ def test_value
14
+ assert_equal :spammy_msg, @sample.value
15
+ end
16
+
17
+ def test_casting_a_tuple_to_sample
18
+ tuple = [:ham, :some_value]
19
+ sample = CrossValidation::Sample(tuple)
20
+ assert_equal :ham, sample.klass
21
+ assert_equal :some_value, sample.value
22
+ end
23
+
24
+ def test_casting_an_incomplete_tuple_to_sample_fails
25
+ assert_raises IndexError do
26
+ CrossValidation::Sample([])
27
+ end
28
+ end
29
+ end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: cross_validation
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Jon-Michael Deldin
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2013-04-06 00:00:00.000000000 Z
11
+ date: 2013-04-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rake
@@ -40,10 +40,15 @@ files:
40
40
  - cross_validation.gemspec
41
41
  - lib/cross_validation.rb
42
42
  - lib/cross_validation/confusion_matrix.rb
43
+ - lib/cross_validation/partitioner.rb
43
44
  - lib/cross_validation/runner.rb
45
+ - lib/cross_validation/sample.rb
46
+ - test/support/spam_classifier.rb
44
47
  - test/test_confusion_matrix.rb
45
48
  - test/test_helper.rb
49
+ - test/test_partitioner.rb
46
50
  - test/test_runner.rb
51
+ - test/test_sample.rb
47
52
  homepage: https://github.com/jmdeldin/cross_validation
48
53
  licenses: []
49
54
  metadata: {}
@@ -68,6 +73,9 @@ signing_key:
68
73
  specification_version: 4
69
74
  summary: Performs k-fold cross-validation on machine learning classifiers.
70
75
  test_files:
76
+ - test/support/spam_classifier.rb
71
77
  - test/test_confusion_matrix.rb
72
78
  - test/test_helper.rb
79
+ - test/test_partitioner.rb
73
80
  - test/test_runner.rb
81
+ - test/test_sample.rb