red-datasets 0.0.5 → 0.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +50 -4
- data/doc/text/news.md +8 -0
- data/lib/datasets/cifar.rb +5 -2
- data/lib/datasets/dataset.rb +1 -1
- data/lib/datasets/iris.rb +2 -1
- data/lib/datasets/metadata.rb +2 -1
- data/lib/datasets/mnist.rb +108 -0
- data/lib/datasets/penn-treebank.rb +65 -0
- data/lib/datasets/version.rb +1 -1
- data/lib/datasets/wikipedia.rb +3 -2
- data/lib/datasets.rb +2 -0
- data/test/helper.rb +1 -0
- data/test/test-cifar.rb +10 -0
- data/test/test-mnist.rb +111 -0
- data/test/test-penn-treebank.rb +59 -0
- data/test/test-wikipedia.rb +7 -2
- metadata +11 -5
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: 3b96f5bf8fb7d8d7280451086dda394b65c42023b15ae077167e2d320c4361c1
         | 
| 4 | 
            +
              data.tar.gz: 96f7936d62d70749f92d3bdd1d7ef2d79cfff3091e7dae8221d6a0537dbd6d7b
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: 859196aa39020d924fa7af4df6d96c110f41ac2b90a39dc89ed6935fc64e857b2bffb5776a366660ab61c55a96dd35b9bd6663ec23c7ee4249cae3103bc0a2aa
         | 
| 7 | 
            +
              data.tar.gz: b07ec53917af58e737058c504685d283850e072f0794c457bd961d39b9815c85b2fc2a9bed4de2a643675dc0e0f7bb2077b4c41b2c28c9c94f948a532baae6bb
         | 
    
        data/README.md
    CHANGED
    
    | @@ -18,7 +18,7 @@ You can use datasets easily because you can access each dataset with multiple wa | |
| 18 18 |  | 
| 19 19 | 
             
            ## Usage
         | 
| 20 20 |  | 
| 21 | 
            -
            Here is an example to access iris  | 
| 21 | 
            +
            Here is an example to access [Iris Data Set](https://archive.ics.uci.edu/ml/datasets/iris) by `#each`  or `Table#to_h` or `Table#fetch_values`.
         | 
| 22 22 |  | 
| 23 23 | 
             
            ```ruby
         | 
| 24 24 | 
             
            require "datasets"
         | 
| @@ -32,12 +32,40 @@ iris.each do |record| | |
| 32 32 | 
             
                 record.petal_width,
         | 
| 33 33 | 
             
                 record.class,
         | 
| 34 34 | 
             
              ]
         | 
| 35 | 
            -
              # [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]
         | 
| 36 | 
            -
              # [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"]
         | 
| 37 35 | 
             
            end
         | 
| 36 | 
            +
            # => [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]
         | 
| 37 | 
            +
            # => [4.9, 3.0, 1.4, 0.2, "Iris-setosa"]
         | 
| 38 | 
            +
              :
         | 
| 39 | 
            +
            # => [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"]
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            iris_hash = iris.to_table.to_h
         | 
| 43 | 
            +
            p iris_hash[:sepal_length]
         | 
| 44 | 
            +
            # => [5.1, 4.9, .. , 7.0, ..
         | 
| 45 | 
            +
            p iris_hash[:sepal_width]
         | 
| 46 | 
            +
            # => [3.5, 3.0, .. , 3.2, ..
         | 
| 47 | 
            +
            p iris_hash[:petal_length]
         | 
| 48 | 
            +
            # => [1.4, 1.4, .. , 4.7, ..
         | 
| 49 | 
            +
            p iris_hash[:petal_width]
         | 
| 50 | 
            +
            # => [0.2, 0.2, .. , 1.4, ..
         | 
| 51 | 
            +
            p iris_hash[:class]
         | 
| 52 | 
            +
            # => ["Iris-setosa", "Iris-setosa", .. , "Iris-versicolor", ..
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            iris_table = iris.to_table
         | 
| 56 | 
            +
            p iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
         | 
| 57 | 
            +
            # => [[5.1, 3.5, 1.4, 0.2],
         | 
| 58 | 
            +
                  [4.9, 3.0, 1.4, 0.2],
         | 
| 59 | 
            +
                  :
         | 
| 60 | 
            +
                  [7.0, 3.2, 4.7, 1.4],
         | 
| 61 | 
            +
                  :
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            p iris_table[:class]
         | 
| 64 | 
            +
            # => ["Iris-setosa", "Iris-setosa", .. , "Iris-versicolor", ..
         | 
| 38 65 | 
             
            ```
         | 
| 39 66 |  | 
| 40 | 
            -
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            Here is an example to access [The CIFAR-10/100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) by `#each`:
         | 
| 41 69 |  | 
| 42 70 | 
             
            **CIFAR-10**
         | 
| 43 71 |  | 
| @@ -73,6 +101,24 @@ cifar.each do |record| | |
| 73 101 | 
             
            end
         | 
| 74 102 | 
             
            ```
         | 
| 75 103 |  | 
| 104 | 
            +
            **MNIST**
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            ```ruby
         | 
| 107 | 
            +
            require "datasets"
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            mnist = Datasets::MNIST.new(type: :train)
         | 
| 110 | 
            +
            mnist.metadata
         | 
| 111 | 
            +
            #=> #<struct Datasets::Metadata name="MNIST-train", url="http://yann.lecun.com/exdb/mnist/", licenses=nil, description="a training set of 60,000 examples">
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            mnist.each do |record|
         | 
| 114 | 
            +
              p record.pixels
         | 
| 115 | 
            +
              # => [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, .....]
         | 
| 116 | 
            +
              p record.label
         | 
| 117 | 
            +
              # => 5
         | 
| 118 | 
            +
            end
         | 
| 119 | 
            +
            ```
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 76 122 | 
             
            ## License
         | 
| 77 123 |  | 
| 78 124 | 
             
            The MIT license. See `LICENSE.txt` for details.
         | 
    
        data/doc/text/news.md
    CHANGED
    
    
    
        data/lib/datasets/cifar.rb
    CHANGED
    
    | @@ -27,14 +27,17 @@ module Datasets | |
| 27 27 |  | 
| 28 28 | 
             
                def initialize(n_classes: 10, type: :train)
         | 
| 29 29 | 
             
                  unless [10, 100].include?(n_classes)
         | 
| 30 | 
            -
                     | 
| 30 | 
            +
                    message = "Please set n_classes 10 or 100: #{n_classes.inspect}"
         | 
| 31 | 
            +
                    raise ArgumentError, message
         | 
| 31 32 | 
             
                  end
         | 
| 32 33 | 
             
                  unless [:train, :test].include?(type)
         | 
| 33 | 
            -
                     | 
| 34 | 
            +
                    message = "Please set type :train or :test: #{type.inspect}"
         | 
| 35 | 
            +
                    raise ArgumentError, message
         | 
| 34 36 | 
             
                  end
         | 
| 35 37 |  | 
| 36 38 | 
             
                  super()
         | 
| 37 39 |  | 
| 40 | 
            +
                  @metadata.id = "cifar-#{n_classes}"
         | 
| 38 41 | 
             
                  @metadata.name = "CIFAR-#{n_classes}"
         | 
| 39 42 | 
             
                  @metadata.url = "https://www.cs.toronto.edu/~kriz/cifar.html"
         | 
| 40 43 | 
             
                  @metadata.description = "CIFAR-#{n_classes} is 32x32 image dataset"
         | 
    
        data/lib/datasets/dataset.rb
    CHANGED
    
    
    
        data/lib/datasets/iris.rb
    CHANGED
    
    
    
        data/lib/datasets/metadata.rb
    CHANGED
    
    
    
        data/lib/datasets/mnist.rb
    ADDED
    
    | @@ -0,0 +1,108 @@ | |
| 1 | 
            +
            require 'zlib'
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            require_relative "dataset"
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            class SetTypeError < StandardError; end
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            module Datasets
         | 
| 8 | 
            +
              class MNIST < Dataset
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                class Record < Struct.new(:data, :label)
         | 
| 11 | 
            +
                  def pixels
         | 
| 12 | 
            +
                    data.unpack("C*")
         | 
| 13 | 
            +
                  end
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                  def to_h
         | 
| 16 | 
            +
                    hash = super
         | 
| 17 | 
            +
                    hash[:pixels] = pixels
         | 
| 18 | 
            +
                    hash
         | 
| 19 | 
            +
                  end
         | 
| 20 | 
            +
                end
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                def initialize(type: :train)
         | 
| 23 | 
            +
                  unless [:train, :test].include?(type)
         | 
| 24 | 
            +
                    raise ArgumentError, "Please set type :train or :test: #{type.inspect}"
         | 
| 25 | 
            +
                  end
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                  super()
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                  @metadata.id = "mnist-#{type}"
         | 
| 30 | 
            +
                  @metadata.name = "MNIST: #{type}"
         | 
| 31 | 
            +
                  @metadata.url = "http://yann.lecun.com/exdb/mnist/"
         | 
| 32 | 
            +
                  @type = type
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                  case type
         | 
| 35 | 
            +
                  when :train
         | 
| 36 | 
            +
                    @metadata.description = "a training set of 60,000 examples"
         | 
| 37 | 
            +
                  when :test
         | 
| 38 | 
            +
                    @metadata.description = "a test set of 10,000 examples"
         | 
| 39 | 
            +
                  end
         | 
| 40 | 
            +
                end
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def each(&block)
         | 
| 43 | 
            +
                  return to_enum(__method__) unless block_given?
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                  image_path = cache_dir_path + target_file(:image)
         | 
| 46 | 
            +
                  label_path = cache_dir_path + target_file(:label)
         | 
| 47 | 
            +
                  base_url = "http://yann.lecun.com/exdb/mnist/"
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                  unless image_path.exist?
         | 
| 50 | 
            +
                    download(image_path, base_url + target_file(:image))
         | 
| 51 | 
            +
                  end
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                  unless label_path.exist?
         | 
| 54 | 
            +
                    download(label_path, base_url + target_file(:label))
         | 
| 55 | 
            +
                  end
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                  open_data(image_path, label_path, &block)
         | 
| 58 | 
            +
                end
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                private
         | 
| 61 | 
            +
                def open_data(image_path, label_path, &block)
         | 
| 62 | 
            +
                  labels = parse_labels(label_path)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                  Zlib::GzipReader.open(image_path) do |f|
         | 
| 65 | 
            +
                    n_uint32s = 4
         | 
| 66 | 
            +
                    n_bytes = n_uint32s * 4
         | 
| 67 | 
            +
                    mnist_magic_number = 2051
         | 
| 68 | 
            +
                    magic, n_images, n_rows, n_cols = f.read(n_bytes).unpack("N*")
         | 
| 69 | 
            +
                    raise 'This is not MNIST image file' if magic != mnist_magic_number
         | 
| 70 | 
            +
                    n_images.times do |i|
         | 
| 71 | 
            +
                      data = f.read(n_rows * n_cols)
         | 
| 72 | 
            +
                      label = labels[i]
         | 
| 73 | 
            +
                      yield Record.new(data, label)
         | 
| 74 | 
            +
                    end
         | 
| 75 | 
            +
                  end
         | 
| 76 | 
            +
                end
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def target_file(data)
         | 
| 79 | 
            +
                  case @type
         | 
| 80 | 
            +
                  when :train
         | 
| 81 | 
            +
                    case data
         | 
| 82 | 
            +
                    when :image
         | 
| 83 | 
            +
                      "train-images-idx3-ubyte.gz"
         | 
| 84 | 
            +
                    when :label
         | 
| 85 | 
            +
                      "train-labels-idx1-ubyte.gz"
         | 
| 86 | 
            +
                    end
         | 
| 87 | 
            +
                  when :test
         | 
| 88 | 
            +
                    case data
         | 
| 89 | 
            +
                    when :image
         | 
| 90 | 
            +
                      "t10k-images-idx3-ubyte.gz"
         | 
| 91 | 
            +
                    when :label
         | 
| 92 | 
            +
                      "t10k-labels-idx1-ubyte.gz"
         | 
| 93 | 
            +
                    end
         | 
| 94 | 
            +
                  end
         | 
| 95 | 
            +
                end
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                def parse_labels(file_path)
         | 
| 98 | 
            +
                  Zlib::GzipReader.open(file_path) do |f|
         | 
| 99 | 
            +
                    n_uint32s = 4
         | 
| 100 | 
            +
                    n_bytes = n_uint32s * 2
         | 
| 101 | 
            +
                    mnist_magic_number = 2049
         | 
| 102 | 
            +
                    magic, n_labels = f.read(n_bytes).unpack('N2')
         | 
| 103 | 
            +
                    raise 'This is not MNIST label file' if magic != mnist_magic_number
         | 
| 104 | 
            +
                    f.read(n_labels).unpack('C*')
         | 
| 105 | 
            +
                  end
         | 
| 106 | 
            +
                end
         | 
| 107 | 
            +
              end
         | 
| 108 | 
            +
            end
         | 
    
        data/lib/datasets/penn-treebank.rb
    ADDED
    
    | @@ -0,0 +1,65 @@ | |
| 1 | 
            +
            require_relative "dataset"
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            module Datasets
         | 
| 4 | 
            +
              class PennTreebank < Dataset
         | 
| 5 | 
            +
                Record = Struct.new(:word, :id)
         | 
| 6 | 
            +
             | 
| 7 | 
            +
                DESCRIPTION = <<~DESC
         | 
| 8 | 
            +
                  `Penn Tree Bank <https://www.cis.upenn.edu/~treebank/>`_ is originally a
         | 
| 9 | 
            +
                  corpus of English sentences with linguistic structure annotations. This
         | 
| 10 | 
            +
                  function uses a variant distributed at
         | 
| 11 | 
            +
                  `https://github.com/wojzaremba/lstm <https://github.com/wojzaremba/lstm>`_,
         | 
| 12 | 
            +
                  which omits the annotation and splits the dataset into three parts:
         | 
| 13 | 
            +
                  training, validation, and test.
         | 
| 14 | 
            +
                DESC
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def initialize(type: :train)
         | 
| 17 | 
            +
                  valid_types = [:train, :test, :valid]
         | 
| 18 | 
            +
                  unless valid_types.include?(type)
         | 
| 19 | 
            +
                    valid_types_label = valid_types.collect(&:inspect).join(", ")
         | 
| 20 | 
            +
                    message = "Type must be one of [#{valid_types_label}]: #{type.inspect}"
         | 
| 21 | 
            +
                    raise ArgumentError, message
         | 
| 22 | 
            +
                  end
         | 
| 23 | 
            +
                  @type = type
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                  super()
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                  @metadata.id = "penn-treebank-#{@type}"
         | 
| 28 | 
            +
                  @metadata.name = "Penn Treebank: #{@type}"
         | 
| 29 | 
            +
                  @metadata.description = DESCRIPTION
         | 
| 30 | 
            +
                  @metadata.url = "https://github.com/wojzaremba/lstm"
         | 
| 31 | 
            +
                  @metadata.licenses = ["Apache-2.0"]
         | 
| 32 | 
            +
                end
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def each(&block)
         | 
| 35 | 
            +
                  return to_enum(__method__) unless block_given?
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                  base_name = "ptb.#{@type}.txt"
         | 
| 38 | 
            +
                  data_path = cache_dir_path + base_name
         | 
| 39 | 
            +
                  unless data_path.exist?
         | 
| 40 | 
            +
                    base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data"
         | 
| 41 | 
            +
                    download(data_path, "#{base_url}/#{base_name}")
         | 
| 42 | 
            +
                  end
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                  parse_data(data_path, &block)
         | 
| 45 | 
            +
                end
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                private
         | 
| 48 | 
            +
                def parse_data(data_path)
         | 
| 49 | 
            +
                  index = 0
         | 
| 50 | 
            +
                  vocabulary = {}
         | 
| 51 | 
            +
                  File.open(data_path) do |f|
         | 
| 52 | 
            +
                    f.each_line do |line|
         | 
| 53 | 
            +
                      line.split.each do |word|
         | 
| 54 | 
            +
                        word = word.strip
         | 
| 55 | 
            +
                        unless vocabulary.key?(word)
         | 
| 56 | 
            +
                          vocabulary[word] = index
         | 
| 57 | 
            +
                          index += 1
         | 
| 58 | 
            +
                        end
         | 
| 59 | 
            +
                        yield(Record.new(word, vocabulary[word]))
         | 
| 60 | 
            +
                      end
         | 
| 61 | 
            +
                    end
         | 
| 62 | 
            +
                  end
         | 
| 63 | 
            +
                end
         | 
| 64 | 
            +
              end
         | 
| 65 | 
            +
            end
         | 
    
        data/lib/datasets/version.rb
    CHANGED
    
    
    
        data/lib/datasets/wikipedia.rb
    CHANGED
    
    | @@ -30,14 +30,15 @@ module Datasets | |
| 30 30 | 
             
                  super()
         | 
| 31 31 | 
             
                  @language = language
         | 
| 32 32 | 
             
                  @type = type
         | 
| 33 | 
            -
                  @metadata. | 
| 33 | 
            +
                  @metadata.id = "wikipedia-#{@language}-#{@type}"
         | 
| 34 | 
            +
                  @metadata.name = "Wikipedia #{@type} (#{@language})"
         | 
| 34 35 | 
             
                  @metadata.url = "https://dumps.wikimedia.org/"
         | 
| 35 36 | 
             
                  @metadata.licenses = [
         | 
| 36 37 | 
             
                    "CC-BY-SA-3.0",
         | 
| 37 38 | 
             
                    "CC-BY-SA-4.0",
         | 
| 38 39 | 
             
                    "GFDL-1.3-or-later",
         | 
| 39 40 | 
             
                  ]
         | 
| 40 | 
            -
                  @metadata.description = "Wikipedia #{@type}  | 
| 41 | 
            +
                  @metadata.description = "Wikipedia #{@type} in #{@language}"
         | 
| 41 42 | 
             
                end
         | 
| 42 43 |  | 
| 43 44 | 
             
                def each(&block)
         | 
    
        data/lib/datasets.rb
    CHANGED
    
    
    
        data/test/helper.rb
    CHANGED
    
    
    
        data/test/test-cifar.rb
    CHANGED
    
    | @@ -216,4 +216,14 @@ class CIFARTest < Test::Unit::TestCase | |
| 216 216 | 
             
                  end
         | 
| 217 217 | 
             
                end
         | 
| 218 218 | 
             
              end
         | 
| 219 | 
            +
             | 
| 220 | 
            +
              sub_test_case("invalid") do
         | 
| 221 | 
            +
                test("type") do
         | 
| 222 | 
            +
                  invalid_type = :invalid
         | 
| 223 | 
            +
                  message = "Please set type :train or :test: #{invalid_type.inspect}"
         | 
| 224 | 
            +
                  assert_raise(ArgumentError.new(message)) do
         | 
| 225 | 
            +
                    Datasets::CIFAR.new(type: invalid_type)
         | 
| 226 | 
            +
                  end
         | 
| 227 | 
            +
                end
         | 
| 228 | 
            +
              end
         | 
| 219 229 | 
             
            end
         | 
    
        data/test/test-mnist.rb
    ADDED
    
    | @@ -0,0 +1,111 @@ | |
| 1 | 
            +
            class MNISTTest < Test::Unit::TestCase
         | 
| 2 | 
            +
              include Helper::Sandbox
         | 
| 3 | 
            +
             | 
| 4 | 
            +
              sub_test_case("Normal") do
         | 
| 5 | 
            +
                def setup_data
         | 
| 6 | 
            +
                  setup_sandbox
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                  def @dataset.cache_dir_path
         | 
| 9 | 
            +
                    @cache_dir_path
         | 
| 10 | 
            +
                  end
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                  def @dataset.cache_dir_path=(path)
         | 
| 13 | 
            +
                    @cache_dir_path = path
         | 
| 14 | 
            +
                  end
         | 
| 15 | 
            +
                  @dataset.cache_dir_path = @tmp_dir
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                  def @dataset.download(output_path, url)
         | 
| 18 | 
            +
                    image_magic_number = 2051
         | 
| 19 | 
            +
                    label_magic_number = 2049
         | 
| 20 | 
            +
                    n_image, image_size_x, image_size_y, label = 10, 28, 28, 1
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    Zlib::GzipWriter.open(output_path) do |gz|
         | 
| 23 | 
            +
                      if output_path.basename.to_s.include?("-images-")
         | 
| 24 | 
            +
                        image_data = ([image_magic_number, n_image]).pack('N2') +
         | 
| 25 | 
            +
                                     ([image_size_x,image_size_y]).pack('N2') +
         | 
| 26 | 
            +
                                     ([0] * image_size_x * image_size_y).pack("C*") * n_image
         | 
| 27 | 
            +
                        gz.puts(image_data)
         | 
| 28 | 
            +
                      else
         | 
| 29 | 
            +
                        label_data = ([label_magic_number, n_image]).pack('N2') +
         | 
| 30 | 
            +
                                     ([label] * n_image).pack("C*")
         | 
| 31 | 
            +
                        gz.puts(label_data)
         | 
| 32 | 
            +
                      end
         | 
| 33 | 
            +
                    end
         | 
| 34 | 
            +
                  end
         | 
| 35 | 
            +
                end
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def teardown
         | 
| 38 | 
            +
                  teardown_sandbox
         | 
| 39 | 
            +
                end
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                sub_test_case("train") do
         | 
| 42 | 
            +
                  def setup
         | 
| 43 | 
            +
                    @dataset = Datasets::MNIST.new(type: :train)
         | 
| 44 | 
            +
                    setup_data()
         | 
| 45 | 
            +
                  end
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                  test("#each") do
         | 
| 48 | 
            +
                    raw_dataset = @dataset.collect do |record|
         | 
| 49 | 
            +
                      {
         | 
| 50 | 
            +
                        :label => record.label,
         | 
| 51 | 
            +
                        :pixels => record.pixels
         | 
| 52 | 
            +
                      }
         | 
| 53 | 
            +
                    end
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    assert_equal([
         | 
| 56 | 
            +
                                   {
         | 
| 57 | 
            +
                                     :label => 1,
         | 
| 58 | 
            +
                                     :pixels => [0] * 28 * 28
         | 
| 59 | 
            +
                                   }
         | 
| 60 | 
            +
                                 ] * 10,
         | 
| 61 | 
            +
                                 raw_dataset)
         | 
| 62 | 
            +
                  end
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                  test("#to_table") do
         | 
| 65 | 
            +
                    table_data = @dataset.to_table
         | 
| 66 | 
            +
                    assert_equal([[0] * 28 * 28] * 10,
         | 
| 67 | 
            +
                                 table_data[:pixels])
         | 
| 68 | 
            +
                  end
         | 
| 69 | 
            +
                end
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                sub_test_case("test") do
         | 
| 72 | 
            +
                  def setup
         | 
| 73 | 
            +
                    @dataset = Datasets::MNIST.new(type: :test)
         | 
| 74 | 
            +
                    setup_data()
         | 
| 75 | 
            +
                  end
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                  test("#each") do
         | 
| 78 | 
            +
                    raw_dataset = @dataset.collect do |record|
         | 
| 79 | 
            +
                      {
         | 
| 80 | 
            +
                        :label => record.label,
         | 
| 81 | 
            +
                        :pixels => record.pixels
         | 
| 82 | 
            +
                      }
         | 
| 83 | 
            +
                    end
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    assert_equal([
         | 
| 86 | 
            +
                                   {
         | 
| 87 | 
            +
                                     :label => 1,
         | 
| 88 | 
            +
                                     :pixels => [0] * 28 * 28
         | 
| 89 | 
            +
                                   }
         | 
| 90 | 
            +
                                 ] * 10,
         | 
| 91 | 
            +
                                 raw_dataset)
         | 
| 92 | 
            +
                  end
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                  test("#to_table") do
         | 
| 95 | 
            +
                    table_data = @dataset.to_table
         | 
| 96 | 
            +
                    assert_equal([[0] * 28 * 28] * 10,
         | 
| 97 | 
            +
                                 table_data[:pixels])
         | 
| 98 | 
            +
                  end
         | 
| 99 | 
            +
                end
         | 
| 100 | 
            +
              end
         | 
| 101 | 
            +
             | 
| 102 | 
            +
              sub_test_case("Abnormal") do
         | 
| 103 | 
            +
                test("invalid type") do
         | 
| 104 | 
            +
                  invalid_type = :invalid
         | 
| 105 | 
            +
                  message = "Please set type :train or :test: #{invalid_type.inspect}"
         | 
| 106 | 
            +
                  assert_raise(ArgumentError.new(message)) do
         | 
| 107 | 
            +
                    Datasets::MNIST.new(type: invalid_type)
         | 
| 108 | 
            +
                  end
         | 
| 109 | 
            +
                end
         | 
| 110 | 
            +
              end
         | 
| 111 | 
            +
            end
         | 
    
        data/test/test-penn-treebank.rb
    ADDED
    
    | @@ -0,0 +1,59 @@ | |
| 1 | 
            +
            class PennTreebankTest < Test::Unit::TestCase
         | 
| 2 | 
            +
              def record(*args)
         | 
| 3 | 
            +
                Datasets::PennTreebank::Record.new(*args)
         | 
| 4 | 
            +
              end
         | 
| 5 | 
            +
             | 
| 6 | 
            +
              sub_test_case("type") do
         | 
| 7 | 
            +
                test("train") do
         | 
| 8 | 
            +
                  dataset = Datasets::PennTreebank.new(type: :train)
         | 
| 9 | 
            +
                  records = dataset.to_a
         | 
| 10 | 
            +
                  assert_equal([
         | 
| 11 | 
            +
                                 887521,
         | 
| 12 | 
            +
                                 record("aer", 0),
         | 
| 13 | 
            +
                                 record("<unk>", 25),
         | 
| 14 | 
            +
                               ],
         | 
| 15 | 
            +
                               [
         | 
| 16 | 
            +
                                 records.size,
         | 
| 17 | 
            +
                                 records[0],
         | 
| 18 | 
            +
                                 records[-1],
         | 
| 19 | 
            +
                               ])
         | 
| 20 | 
            +
                end
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                test("test") do
         | 
| 23 | 
            +
                  dataset = Datasets::PennTreebank.new(type: :test)
         | 
| 24 | 
            +
                  records = dataset.to_a
         | 
| 25 | 
            +
                  assert_equal([
         | 
| 26 | 
            +
                                 78669,
         | 
| 27 | 
            +
                                 record("no", 0),
         | 
| 28 | 
            +
                                 record("us", 953),
         | 
| 29 | 
            +
                               ],
         | 
| 30 | 
            +
                               [
         | 
| 31 | 
            +
                                 records.size,
         | 
| 32 | 
            +
                                 records[0],
         | 
| 33 | 
            +
                                 records[-1],
         | 
| 34 | 
            +
                               ])
         | 
| 35 | 
            +
                end
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                test("valid") do
         | 
| 38 | 
            +
                  dataset = Datasets::PennTreebank.new(type: :valid)
         | 
| 39 | 
            +
                  records = dataset.to_a
         | 
| 40 | 
            +
                  assert_equal([
         | 
| 41 | 
            +
                                 70390,
         | 
| 42 | 
            +
                                 record("consumers", 0),
         | 
| 43 | 
            +
                                 record("N", 28),
         | 
| 44 | 
            +
                               ],
         | 
| 45 | 
            +
                               [
         | 
| 46 | 
            +
                                 records.size,
         | 
| 47 | 
            +
                                 records[0],
         | 
| 48 | 
            +
                                 records[-1],
         | 
| 49 | 
            +
                               ])
         | 
| 50 | 
            +
                end
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                test("invalid") do
         | 
| 53 | 
            +
                  message = "Type must be one of [:train, :test, :valid]: :invalid"
         | 
| 54 | 
            +
                  assert_raise(ArgumentError.new(message)) do
         | 
| 55 | 
            +
                    Datasets::PennTreebank.new(type: :invalid)
         | 
| 56 | 
            +
                  end
         | 
| 57 | 
            +
                end
         | 
| 58 | 
            +
              end
         | 
| 59 | 
            +
            end
         | 
    
        data/test/test-wikipedia.rb
    CHANGED
    
    | @@ -83,13 +83,18 @@ class WikipediaTest < Test::Unit::TestCase | |
| 83 83 | 
             
                  end
         | 
| 84 84 |  | 
| 85 85 | 
             
                  sub_test_case("#metadata") do
         | 
| 86 | 
            -
                    test("# | 
| 86 | 
            +
                    test("#id") do
         | 
| 87 87 | 
             
                      assert_equal("wikipedia-ja-articles",
         | 
| 88 | 
            +
                                   @dataset.metadata.id)
         | 
| 89 | 
            +
                    end
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    test("#name") do
         | 
| 92 | 
            +
                      assert_equal("Wikipedia articles (ja)",
         | 
| 88 93 | 
             
                                   @dataset.metadata.name)
         | 
| 89 94 | 
             
                    end
         | 
| 90 95 |  | 
| 91 96 | 
             
                    test("#description") do
         | 
| 92 | 
            -
                      assert_equal("Wikipedia articles  | 
| 97 | 
            +
                      assert_equal("Wikipedia articles in ja",
         | 
| 93 98 | 
             
                                   @dataset.metadata.description)
         | 
| 94 99 | 
             
                    end
         | 
| 95 100 | 
             
                  end
         | 
    
        metadata
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: red-datasets
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 0.0.5
         | 
| 4 | 
            +
              version: 0.0.6
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - tomisuker
         | 
| @@ -9,7 +9,7 @@ authors: | |
| 9 9 | 
             
            autorequire: 
         | 
| 10 10 | 
             
            bindir: bin
         | 
| 11 11 | 
             
            cert_chain: []
         | 
| 12 | 
            -
            date: 2018- | 
| 12 | 
            +
            date: 2018-07-25 00:00:00.000000000 Z
         | 
| 13 13 | 
             
            dependencies:
         | 
| 14 14 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 15 15 | 
             
              name: bundler
         | 
| @@ -104,6 +104,8 @@ files: | |
| 104 104 | 
             
            - lib/datasets/downloader.rb
         | 
| 105 105 | 
             
            - lib/datasets/iris.rb
         | 
| 106 106 | 
             
            - lib/datasets/metadata.rb
         | 
| 107 | 
            +
            - lib/datasets/mnist.rb
         | 
| 108 | 
            +
            - lib/datasets/penn-treebank.rb
         | 
| 107 109 | 
             
            - lib/datasets/table.rb
         | 
| 108 110 | 
             
            - lib/datasets/version.rb
         | 
| 109 111 | 
             
            - lib/datasets/wikipedia.rb
         | 
| @@ -112,6 +114,8 @@ files: | |
| 112 114 | 
             
            - test/run-test.rb
         | 
| 113 115 | 
             
            - test/test-cifar.rb
         | 
| 114 116 | 
             
            - test/test-iris.rb
         | 
| 117 | 
            +
            - test/test-mnist.rb
         | 
| 118 | 
            +
            - test/test-penn-treebank.rb
         | 
| 115 119 | 
             
            - test/test-table.rb
         | 
| 116 120 | 
             
            - test/test-wikipedia.rb
         | 
| 117 121 | 
             
            homepage: https://github.com/red-data-tools/red-datasets
         | 
| @@ -139,9 +143,11 @@ signing_key: | |
| 139 143 | 
             
            specification_version: 4
         | 
| 140 144 | 
             
            summary: Red Datasets provides classes that provide common datasets such as iris dataset.
         | 
| 141 145 | 
             
            test_files:
         | 
| 142 | 
            -
            - test/test | 
| 146 | 
            +
            - test/run-test.rb
         | 
| 147 | 
            +
            - test/test-cifar.rb
         | 
| 143 148 | 
             
            - test/test-wikipedia.rb
         | 
| 149 | 
            +
            - test/test-iris.rb
         | 
| 144 150 | 
             
            - test/helper.rb
         | 
| 145 | 
            -
            - test/ | 
| 151 | 
            +
            - test/test-mnist.rb
         | 
| 146 152 | 
             
            - test/test-table.rb
         | 
| 147 | 
            -
            - test/test- | 
| 153 | 
            +
            - test/test-penn-treebank.rb
         |