cross_validation 0.0.1 → 0.0.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA1:
-   metadata.gz: e36ebf38a97ebf8665474186f2c1ae3c82a3855d
-   data.tar.gz: 770c131f2d8e3cf359e7e26450cb7a24c704239c
+   metadata.gz: 5f5fee4f5040fd1ad3562518031b850f7cf8c4b0
+   data.tar.gz: 7076cc8ff37d8967c84f04a493a5dff8729e0202
  SHA512:
-   metadata.gz: 34ce7b4484db03a2d09aeb2b1f40c586a84a63b6ba22d355220ee6cba8bf588f3c8a2ac867a82524d43aa9fba06863f4f44ca09af8a17cb7f56409d9105dce7f
-   data.tar.gz: d55546305d845c1f2c977f825f2d69a77b32369c3afc13b5262dd170b523f4efffa6f7fa9f6d9f7d9215b6b008156ff1ae2cb13993625268bc8d3e1f0b4b2f4c
+   metadata.gz: f27ea09189c51f97db89aaee88dee9cca8adefbc503fe596f91297ed9acc6273be0c3414dfd6b1004b9e32ff2cb98cc22edafaa3484ddf062e965a609bdff39b
+   data.tar.gz: 04e0bf2d0863102dadddcb053d053c8d4dd5bbb0098acc58b3205a09438d56c5c3077aa3529302017f709f1bab8cf7a7440b690d047b6bd1cb36277daf8eb810
data/README.md CHANGED
@@ -4,7 +4,9 @@
  [![Code Climate](https://codeclimate.com/github/jmdeldin/cross_validation.png)](https://codeclimate.com/github/jmdeldin/cross_validation)

  This gem provides a k-fold cross-validation routine and confusion matrix
- for evaluating machine learning classifiers.
+ for evaluating machine learning classifiers. See [below](#usage) for
+ usage or jump to the
+ [documentation](http://rubydoc.info/github/jmdeldin/cross_validation/frames).

  ## Installation

@@ -22,19 +24,89 @@ Or install it yourself as:

  ## Usage

- Cross-validation:
-
- Confusion-matrix:
-
-
- ## Contributing
-
- 1. Fork it
- 2. Create your feature branch (`git checkout -b my-new-feature`)
- 3. Commit your changes (`git commit -am 'Add some feature'`)
- 4. Push to the branch (`git push origin my-new-feature`)
- 5. Create new Pull Request
-
- ## Questions
-
- Send me an email, `dev@jmdeldin.com`
+ To cross-validate your classifier, you need to configure a run as
+ follows:
+
+ ```ruby
+ require 'cross_validation'
+
+ runner = CrossValidation::Runner.create do |r|
+   r.documents = my_array_of_documents
+   r.folds = 10
+   # or if you'd rather test on 10%
+   # r.percentage = 0.1
+   r.classifier = lambda { SpamClassifier.new }
+   r.fetch_sample_class = lambda { |sample| sample.klass }
+   r.fetch_sample_value = lambda { |sample| sample.value }
+   r.matrix = CrossValidation::ConfusionMatrix.new(method(:keys_for))
+   r.training = lambda { |classifier, doc|
+     classifier.train doc.klass, doc.value
+   }
+   r.classifying = lambda { |classifier, doc|
+     classifier.classify doc
+   }
+ end
+ ```
+
+ With the run configured, just invoke `#run` to return a confusion matrix:
+
+ ```ruby
+ mat = runner.run
+ ```
+
+ With a confusion matrix in hand, you can compute many statistics about
+ your classifier:
+
+ - `mat.accuracy`
+ - `mat.f1`
+ - `mat.fscore(beta)`
+ - `mat.precision`
+ - `mat.recall`
+
+ Please see the
+ [respective documentation](http://rubydoc.info/github/jmdeldin/cross_validation/CrossValidation/ConfusionMatrix)
+ for each method for more details.
+
+ ### Defining `keys_for`
+
+ The ConfusionMatrix class requires a `keys_for` `Proc` that returns a
+ symbol. In this method, you specify what constitutes a true positive
+ (`:tp`), true negative (`:tn`), false positive (`:fp`), and false
+ negative (`:fn`). For example, in spam classification, you can construct
+ the following table to write the keys_for method:
+
+                    actual
+           +---------------------------------
+  expected | correct        | not correct
+ ----------+----------------+----------------
+  spam     | true positive  | false positive
+  ham      | true negative  | false negative
+
+ You can then implement this table with nested hashes or just a few
+ conditionals:
+
+ ```ruby
+ def keys_for(expected, actual)
+   if expected == :spam
+     actual == :spam ? :tp : :fp
+   elsif expected == :ham
+     actual == :ham ? :tn : :fn
+   end
+ end
+ ```
+
+ Once you have your `keys_for` method implemented, pass it into the
+ ConfusionMatrix with `method(:keys_for)`, or if it's a class-method,
+ `MyClass.method(:keys_for)`. (You can also implement the method as a
+ lambda.)
+
+ ## Roadmap
+
+ For v1.0:
+
+ - Implement configurable, parallel cross-validation
+ - Include more complete examples
+
+ ## Author
+
+ Jon-Michael Deldin, `dev@jmdeldin.com`
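The README above lists `accuracy`, `f1`, `fscore(beta)`, `precision` and `recall` without spelling out the formulas. The sketch below computes the usual textbook definitions from the four counts that `keys_for` assigns. The `Counts` struct is hypothetical and is not the gem's `ConfusionMatrix`, whose handling of edge cases such as zero denominators may differ.

```ruby
# Hypothetical helper for illustration only; this is NOT the gem's
# ConfusionMatrix. It applies the textbook formulas to raw counts.
Counts = Struct.new(:tp, :tn, :fp, :fn) do
  def total
    tp + tn + fp + fn
  end

  # Share of all predictions that were correct.
  def accuracy
    (tp + tn).fdiv(total)
  end

  # Same relationship as ConfusionMatrix#error in this release.
  def error
    1.0 - accuracy
  end

  # Of everything predicted positive, the share that really was positive.
  def precision
    tp.fdiv(tp + fp)
  end

  # Of everything that really was positive, the share that was found.
  def recall
    tp.fdiv(tp + fn)
  end

  # Weighted harmonic mean of precision and recall.
  def fscore(beta)
    b2   = beta**2
    prec = precision
    rec  = recall
    (1 + b2) * prec * rec / (b2 * prec + rec)
  end

  def f1
    fscore(1)
  end
end

c = Counts.new(40, 45, 5, 10) # tp, tn, fp, fn
puts c.accuracy # prints 0.85
```

Setting `beta` above 1 weights recall more heavily, below 1 favours precision, and `beta = 1` reduces to F1, which equals `2 * tp / (2 * tp + fp + fn)`.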
@@ -1,5 +1,9 @@
  $LOAD_PATH.unshift File.dirname(__FILE__)

  module CrossValidation
-   VERSION = '0.0.1'
+   VERSION = '0.0.2'
+ end
+
+ %w(confusion_matrix runner).each do |fn|
+   require File.join('cross_validation', fn)
  end
@@ -34,7 +34,7 @@ module CrossValidation
      # @param [Object] truth The known, expected value
      # @return [self]
      def store(actual, truth)
-       key = @keys_for.call(actual, truth)
+       key = @keys_for.call(truth, actual)

        if @values.key?(key)
          @values[key] += 1
@@ -83,6 +83,8 @@ module CrossValidation
      end

      # Returns the classifier's error
+     #
+     # @return [Float]
      def error
        1.0 - accuracy()
      end
@@ -0,0 +1,34 @@
+ module CrossValidation
+   # Provides helper methods for data partitioning.
+   #
+   module Partitioner
+
+     # Splits the array into +k+-sized subsets.
+     #
+     # For example, calling this method for the array +%w(foo bar baz qux)+
+     # with +k=2+ results in an array of arrays: +[[foo, bar], [baz, qux]]+.
+     #
+     # @param [Array] ary documents to work with
+     # @param [Fixnum] k size of each subset
+     # @return [Array] array of arrays
+     # @raise [ArgumentError] if the length of the documents array is not
+     #   evenly divisible by k
+     def self.subset(ary, k)
+       if ary.length % k > 0
+         fail ArgumentError, "Can't create equal subsets when k=#{k}"
+       end
+
+       ary.each_slice(k).to_a
+     end
+
+     # Returns a flattened copy of the original array without an element at
+     # index +i+.
+     #
+     # @param [Array] ary subsets to work with (e.g., array of arrays)
+     # @param [Fixnum] i index to remove
+     # @return [Array]
+     def self.exclude_index(ary, i)
+       ary.rotate(i).drop(1).flatten
+     end
+   end
+ end
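`subset` cuts the documents into consecutive `k`-element slices, and `exclude_index` is documented as taking an array of subsets. The self-contained sketch below copies both helpers into a hypothetical `PartitionerSketch` module so it runs without the gem, and shows how they combine so that each subset takes one turn as the test set while the rest form the training set. Because `k` is the size of each subset here, the number of subsets is `ary.length / k`.

```ruby
# Copies of the two helpers from the diff above, in a hypothetical module
# so the example runs without the gem installed.
module PartitionerSketch
  # Split +ary+ into consecutive slices of +k+ elements each.
  def self.subset(ary, k)
    if ary.length % k > 0
      fail ArgumentError, "Can't create equal subsets when k=#{k}"
    end

    ary.each_slice(k).to_a
  end

  # Drop the element at index +i+ and flatten what remains.
  def self.exclude_index(ary, i)
    ary.rotate(i).drop(1).flatten
  end
end

docs   = %w(foo bar baz qux)
groups = PartitionerSketch.subset(docs, 2) # => [["foo", "bar"], ["baz", "qux"]]

# Each subset takes one turn as the test set; the others become training data.
groups.each_with_index do |test_set, i|
  training_set = PartitionerSketch.exclude_index(groups, i)
  puts "test=#{test_set.inspect} train=#{training_set.inspect}"
end
```

Passing the flat document list instead of `groups` still runs, but then only a single document is dropped, which is the behaviour `test_exclude_by_index` checks.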
@@ -1,4 +1,6 @@
  require_relative '../cross_validation'
+ require_relative 'partitioner'
+ require_relative 'sample'

  module CrossValidation
    class Runner
@@ -43,6 +45,17 @@ module CrossValidation
      # document and should return the document's class.
      attr_accessor :fetch_sample_class

+     # @return [Array] Array of which attributes are empty
+     attr_reader :errors
+
+     def initialize
+       @fetch_sample_value = lambda { |sample| sample.value }
+       @fetch_sample_class = lambda { |sample| sample.klass }
+
+       @critical_keys = [:documents, :classifier, :matrix, :training,
+                         :classifying, :fetch_sample_value, :fetch_sample_class]
+     end
+
      # Returns the number of folds to partition the documents into.
      #
      # @return [Fixnum]
@@ -50,6 +63,24 @@ module CrossValidation
        @k ||= percentage ? (documents.size * percentage) : folds
      end

+     # Checks if all of the required run parameters are set.
+     #
+     # @return [Boolean]
+     def valid?
+       @errors = []
+       @critical_keys.each do |k|
+         any_error = public_send(k).nil?
+         @errors << k if any_error
+       end
+
+       @errors.size == 0
+     end
+
+     # @see #valid?
+     def invalid?
+       !valid?
+     end
+
      # Performs k-fold cross-validation and returns a confusion matrix.
      #
      # The algorithm is as follows (Mitchell, 1997, p147):
@@ -61,24 +92,23 @@ module CrossValidation
      # classify(partitions[i])
      # output confusion matrix
      #
+     # @raise [ArgumentError] if the runner is missing required attributes
+     # @return [ConfusionMatrix]
      def run
-       partitions = documents.each_slice(k).to_a
+       fail_if_invalid
+
+       partitions = Partitioner.subset(documents, k)

        results = partitions.map.with_index do |part, i|
-         # Array#rotate puts the element i first, so all we have to do is rotate
-         # then remove that element to get the training set. Array#drop does not
-         # mutate the original array either. Array#flatten is needed to coalesce
-         # our list of lists into one list again.
-         training_samples = partitions.rotate(i).drop(1).flatten
+         training_samples = Partitioner.exclude_index(documents, i)

          classifier_instance = classifier.call()

-         # train it
-         training_samples.each { |doc| training.call(classifier_instance, doc) }
+         train(classifier_instance, training_samples)

          # fetch confusion keys
          part.each do |x|
-           prediction = classifying.call(classifier_instance, fetch_sample_value.call(x))
+           prediction = classify(classifier_instance, x)
            matrix.store(prediction, fetch_sample_class.call(x))
          end
        end
@@ -91,5 +121,23 @@ module CrossValidation
      def self.create
        new.tap { |r| yield(r) }
      end
+
+     private
+
+     def fail_if_invalid
+       return nil if valid?
+       msg = "The following attribute(s) must be specified: #{errors.join(', ')}"
+       fail ArgumentError, msg
+     end
+
+     def train(classifier_instance, samples)
+       samples.each do |doc|
+         training.call(classifier_instance, doc)
+       end
+     end
+
+     def classify(classifier_instance, sample)
+       classifying.call(classifier_instance, fetch_sample_value.call(sample))
+     end
    end
  end
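The `run` docstring follows Mitchell's k-fold outline: partition the data, then for each partition train a fresh classifier on the other partitions and classify the held-out one. A minimal stand-alone sketch of that loop follows. `k_fold`, its keyword arguments and the `Doc` struct are hypothetical names, not the gem's API. Unlike `Partitioner.subset`, this sketch treats `k` as the number of folds, and it flattens only one level so that documents which are themselves arrays stay intact.

```ruby
# Hypothetical stand-alone k-fold loop; not the gem's Runner API.
# Returns one [prediction, truth] pair per document.
def k_fold(docs, k, build:, train:, classify:, label:)
  if docs.length % k > 0
    raise ArgumentError, "#{docs.length} documents can't form #{k} equal folds"
  end

  folds = docs.each_slice(docs.length / k).to_a # exactly k folds

  folds.each_with_index.flat_map do |test_set, i|
    training_set = folds.rotate(i).drop(1).flatten(1) # every fold except i
    model = build.call                                # fresh model per fold
    training_set.each { |doc| train.call(model, doc) }
    test_set.map { |doc| [classify.call(model, doc), label.call(doc)] }
  end
end

Doc = Struct.new(:klass, :value)
corpus = [Doc.new(:spam, 'Buy some viagra!'), Doc.new(:ham, 'Lunch?')] * 5

pairs = k_fold(corpus, 5,
               build:    -> { Object.new },   # toy model with no state
               train:    ->(_model, _doc) {}, # nothing to learn
               classify: ->(_model, doc) { doc.value =~ /viagra/ ? :spam : :ham },
               label:    ->(doc) { doc.klass })

correct = pairs.count { |prediction, truth| prediction == truth }
puts "#{correct}/#{pairs.size} correct" # prints "10/10 correct"
```

Feeding each `[prediction, truth]` pair into a confusion matrix's `store` would reproduce the final step of the outline.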
@@ -0,0 +1,15 @@
+ module CrossValidation
+   # Represents a datum and its class (e.g., "spam").
+   #
+   # This is an optional data structure that simplifies defining training
+   # methods in cross-validation runs.
+   Sample = Struct.new(:klass, :value)
+
+   # Converts an array of +[class, value]+ into a `Sample` object.
+   #
+   # @param [Array] tuple
+   # @return [Sample]
+   def self.Sample(tuple)
+     Sample.new(tuple.fetch(0), tuple.fetch(1))
+   end
+ end
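`CrossValidation::Sample` pairs a struct with a same-named conversion method, the naming pattern Ruby's core uses for `Integer()` and `Array()`. The sketch below mirrors it under a hypothetical `SampleSketch` module to show raw `[class, value]` rows being turned into samples. Because the method uses `Array#fetch`, a row with fewer than two elements raises `IndexError`, which is what `test_casting_an_incomplete_tuple_to_sample_fails` checks.

```ruby
# Stand-alone mirror of CrossValidation::Sample; SampleSketch is hypothetical.
module SampleSketch
  Sample = Struct.new(:klass, :value)

  # Conversion method named after the constant it builds.
  def self.Sample(tuple)
    Sample.new(tuple.fetch(0), tuple.fetch(1))
  end
end

rows = [[:spam, 'Buy some viagra!'], [:ham, 'Lunch at noon?']]
samples = rows.map { |row| SampleSketch::Sample(row) }

samples.first.klass # => :spam
samples.last.value  # => "Lunch at noon?"
```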
@@ -0,0 +1,24 @@
+ # A toy classifier. As long as you can tell the CrossValidation gem how to
+ # invoke your training and classifying methods, then you can do whatever you
+ # want in your classifier.
+ class SpamClassifier
+   def train(klass, document)
+     # don't bother, we're that good (in reality, you should probably do some
+     # work here)
+   end
+
+   def classify(document)
+     document =~ /viagra/ ? :spam : :ham
+   end
+
+   # Dummy method for use in testing confusion matrices. Used to determine
+   # whether a class is a true positive|negative or a false positive|negative.
+   # This is used when configuring a confusion matrix.
+   def self.keys_for(expected, actual)
+     if expected == :spam
+       actual == :spam ? :tp : :fp
+     elsif expected == :ham
+       actual == :ham ? :tn : :fn
+     end
+   end
+ end
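The README notes that `keys_for` can be written with nested hashes instead of conditionals. A sketch of that variant follows, assuming only the two labels `:spam` and `:ham`. The `KEYS` table is hypothetical; it reproduces the labels that this release's table and `SpamClassifier.keys_for` assign, and only the data structure differs. Unlike the conditional version, it raises `KeyError` for any label outside the table instead of returning `:fp`, `:fn` or `nil`.

```ruby
# Hypothetical lookup table: KEYS[expected][actual] => confusion key.
# The labels mirror SpamClassifier.keys_for in this release.
KEYS = {
  spam: { spam: :tp, ham: :fp },
  ham:  { ham: :tn, spam: :fn }
}.freeze

# Hash#fetch raises KeyError for labels missing from the table.
def keys_for(expected, actual)
  KEYS.fetch(expected).fetch(actual)
end

keys_for(:spam, :spam) # => :tp
keys_for(:ham, :spam)  # => :fn
```

It can be handed to a matrix the same way as the conditional version, via `method(:keys_for)`.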
@@ -1,4 +1,5 @@
  require_relative 'test_helper'
+ require_relative 'support/spam_classifier'
  require_relative '../lib/cross_validation/confusion_matrix'

  class TestConfusionMatrix < MiniTest::Unit::TestCase
@@ -7,7 +8,7 @@ class TestConfusionMatrix < MiniTest::Unit::TestCase
    end

    def setup
-     @mat = CrossValidation::ConfusionMatrix.new(method(:keys_for))
+     @mat = CrossValidation::ConfusionMatrix.new(SpamClassifier.method(:keys_for))
    end

    def test_true_positives
@@ -31,7 +32,7 @@ class TestConfusionMatrix < MiniTest::Unit::TestCase
    end

    def test_store_raises_index_error_on_bad_key
-     bad_keys_for = ->(actual, expected) { :bad }
+     bad_keys_for = ->(expected, actual) { :bad }
      mat = CrossValidation::ConfusionMatrix.new(bad_keys_for)
      assert_raises IndexError do
        mat.store(:ham, :spam)
@@ -1,10 +1 @@
  require 'minitest/autorun'
-
- # Dummy method for use in testing confusion matrices.
- def keys_for(actual, expected)
-   if actual == :spam
-     expected == :spam ? :tp : :fn
-   elsif actual == :ham
-     expected == :ham ? :tn : :fp
-   end
- end
@@ -0,0 +1,28 @@
+ require_relative 'test_helper'
+ require_relative '../lib/cross_validation/partitioner'
+
+ class TestPartitioner < MiniTest::Unit::TestCase
+   def setup
+     @docs = %w(foo bar baz qux)
+     @p = CrossValidation::Partitioner
+   end
+
+   def test_create_equal_subsets_returns_equal_subsets
+     subsets = @p.subset(@docs, 2)
+
+     assert_equal %w(foo bar), subsets.first
+     assert_equal %w(baz qux), subsets.last
+   end
+
+   def test_create_equal_subsets_prevents_unequal_subsets
+     e = assert_raises ArgumentError do
+       @p.subset(@docs, 3)
+     end
+     assert_equal "Can't create equal subsets when k=3", e.message
+   end
+
+   def test_exclude_by_index
+     samples = @p.exclude_index(@docs, 1)
+     assert_equal %w(baz qux foo), samples
+   end
+ end
@@ -1,24 +1,9 @@
  require_relative 'test_helper'
+ require_relative 'support/spam_classifier'
  require_relative '../lib/cross_validation/confusion_matrix'
+ require_relative '../lib/cross_validation/sample'
  require_relative '../lib/cross_validation/runner'

- # A stupid classifier
- class SpamClassifier
-   def train(klass, document)
-     # don't bother, we're that good (in reality, you should probably do some
-     # work here)
-   end
-
-   def classify(document)
-     document =~ /viagra/ ? :spam : :ham
-   end
- end
-
- # We just need to associate a class with a value. Feel free to use whatever
- # data structure you like -- this is only used in user-defined training and
- # classifying closures.
- Sample = Struct.new(:klass, :value)
-
  # Asserts the DSL's getter and setters work.
  def check_dsl(attribute, value)
    runner = CrossValidation::Runner.create { |r|
@@ -28,13 +13,17 @@ def check_dsl(attribute, value)
    define_method("test_#{attribute}_getter") {
      assert_equal :value, runner.public_send(attribute)
    }
+
+   define_method("test_runner_is_invalid_with_only_#{attribute}_set") {
+     assert runner.invalid?
+   }
  end

  class TestRunner < MiniTest::Unit::TestCase
    def setup
      tpl = ['Buy some...', 'Would you like some...']
-     @spam = tpl.map { |pfx| Sample.new(:spam, pfx + 'viagra!') }
-     @ham = tpl.map { |pfx| Sample.new(:ham, pfx + 'penicillin!') }
+     @spam = tpl.map { |pfx| CrossValidation::Sample.new(:spam, pfx + 'viagra!') }
+     @ham = tpl.map { |pfx| CrossValidation::Sample.new(:ham, pfx + 'penicillin!') }
      @corpus = @spam + @ham
      @corpus *= 25 # 100 is easier to deal with
    end
@@ -44,9 +33,7 @@ class TestRunner < MiniTest::Unit::TestCase
        r.documents = @corpus
        r.folds = 10
        r.classifier = lambda { SpamClassifier.new }
-       r.fetch_sample_class = lambda { |sample| sample.klass }
-       r.fetch_sample_value = lambda { |sample| sample.value }
-       r.matrix = CrossValidation::ConfusionMatrix.new(method(:keys_for))
+       r.matrix = CrossValidation::ConfusionMatrix.new(SpamClassifier.method(:keys_for))
        r.training = lambda { |classifier, doc|
          classifier.train doc.klass, doc.value
        }
@@ -83,4 +70,12 @@ class TestRunner < MiniTest::Unit::TestCase
    ].each do |attribute|
      check_dsl(attribute, :foo)
    end
+
+   def test_invalid_runner_raises_error
+     runner = CrossValidation::Runner.create {}
+     exception = assert_raises ArgumentError do
+       runner.run
+     end
+     assert_match(/must be specified/, exception.message)
+   end
  end
@@ -0,0 +1,29 @@
+ require_relative 'test_helper'
+ require_relative '../lib/cross_validation/sample'
+
+ class TestSample < MiniTest::Unit::TestCase
+   def setup
+     @sample = CrossValidation::Sample.new(:spam, :spammy_msg)
+   end
+
+   def test_klass
+     assert_equal :spam, @sample.klass
+   end
+
+   def test_value
+     assert_equal :spammy_msg, @sample.value
+   end
+
+   def test_casting_a_tuple_to_sample
+     tuple = [:ham, :some_value]
+     sample = CrossValidation::Sample(tuple)
+     assert_equal :ham, sample.klass
+     assert_equal :some_value, sample.value
+   end
+
+   def test_casting_an_incomplete_tuple_to_sample_fails
+     assert_raises IndexError do
+       CrossValidation::Sample([])
+     end
+   end
+ end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: cross_validation
  version: !ruby/object:Gem::Version
-   version: 0.0.1
+   version: 0.0.2
  platform: ruby
  authors:
  - Jon-Michael Deldin
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2013-04-06 00:00:00.000000000 Z
+ date: 2013-04-15 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rake
@@ -40,10 +40,15 @@ files:
  - cross_validation.gemspec
  - lib/cross_validation.rb
  - lib/cross_validation/confusion_matrix.rb
+ - lib/cross_validation/partitioner.rb
  - lib/cross_validation/runner.rb
+ - lib/cross_validation/sample.rb
+ - test/support/spam_classifier.rb
  - test/test_confusion_matrix.rb
  - test/test_helper.rb
+ - test/test_partitioner.rb
  - test/test_runner.rb
+ - test/test_sample.rb
  homepage: https://github.com/jmdeldin/cross_validation
  licenses: []
  metadata: {}
@@ -68,6 +73,9 @@ signing_key:
  specification_version: 4
  summary: Performs k-fold cross-validation on machine learning classifiers.
  test_files:
+ - test/support/spam_classifier.rb
  - test/test_confusion_matrix.rb
  - test/test_helper.rb
+ - test/test_partitioner.rb
  - test/test_runner.rb
+ - test/test_sample.rb