easy_ml 0.1.1 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/app/models/easy_ml/model.rb +11 -2
- data/lib/easy_ml/core/model.rb +1 -0
- data/lib/easy_ml/core/model_core.rb +38 -38
- data/lib/easy_ml/core/models/xgboost_core.rb +1 -1
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -1
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +19 -11
- data/lib/easy_ml/support/file_rotate.rb +1 -1
- data/lib/easy_ml/version.rb +1 -1
- data/lib/easy_ml.rb +0 -1
- metadata +46 -5
- data/lib/easy_ml/trainer.rb +0 -92
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: c68319b0bcca7b83e1fb4a539a15b42ced62475f8c8ab7f6b9a7a164c3b02cea
         | 
| 4 | 
            +
              data.tar.gz: 1a6a83e77ff7723c590ab1238ce12d80aa6aa15bb5897b95a738c0a00760fb48
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: 8e1e3fdcbae41205892d47be90407b982efdce329b54b635ea2870d141297780ce6f166ab1769f3d531bf51906ff25c17a2d1b8d6a8aed0ea4e1e12052ff8eec
         | 
| 7 | 
            +
              data.tar.gz: 31b5b12f27734f44fa92015feb8e4ac5c868f38090c3244a90f15b83327427fc05f3d2a34446d1b64f11ba10903bb1b989e890b8a63381c3da2e5a1330b5f778
         | 
    
        data/app/models/easy_ml/model.rb
    CHANGED
    
    | @@ -1,9 +1,18 @@ | |
| 1 1 | 
             
            require_relative "../../../lib/easy_ml/core/model"
         | 
| 2 2 | 
             
            module EasyML
         | 
| 3 3 | 
             
              class Model < ActiveRecord::Base
         | 
| 4 | 
            -
                 | 
| 4 | 
            +
                if ActiveRecord::Base.connection.data_source_exists?("easy_ml_models")
         | 
| 5 | 
            +
                  include EasyML::Core::ModelCore
         | 
| 5 6 |  | 
| 6 | 
            -
             | 
| 7 | 
            +
                  self.table_name = "easy_ml_models"
         | 
| 8 | 
            +
                else
         | 
| 9 | 
            +
                  # Placeholder if the table doesn't exist (keeps the file quiet)
         | 
| 10 | 
            +
                  def self.table_ready?
         | 
| 11 | 
            +
                    false
         | 
| 12 | 
            +
                  end
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                  Rails.logger.info("Skipping EasyML::Model definition as the 'easy_ml_models' table doesn't exist.")
         | 
| 15 | 
            +
                end
         | 
| 7 16 |  | 
| 8 17 | 
             
                scope :live, -> { where(is_live: true) }
         | 
| 9 18 | 
             
                attribute :root_dir, :string
         | 
    
        data/lib/easy_ml/core/model.rb
    CHANGED
    
    
| @@ -21,30 +21,6 @@ module EasyML | |
| 21 21 | 
             
                    end
         | 
| 22 22 | 
             
                  end
         | 
| 23 23 |  | 
| 24 | 
            -
                  def dataset_is_a_dataset?
         | 
| 25 | 
            -
                    return if dataset.nil?
         | 
| 26 | 
            -
                    return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
         | 
| 27 | 
            -
             | 
| 28 | 
            -
                    errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
         | 
| 29 | 
            -
                  end
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                  def validate_any_metrics?
         | 
| 32 | 
            -
                    return if metrics.any?
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                    errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
         | 
| 35 | 
            -
                  end
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                  def validate_metrics_for_task
         | 
| 38 | 
            -
                    nonsensical_metrics = metrics.select do |metric|
         | 
| 39 | 
            -
                      allowed_metrics.exclude?(metric)
         | 
| 40 | 
            -
                    end
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                    return unless nonsensical_metrics.any?
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                    errors.add(:metrics,
         | 
| 45 | 
            -
                               "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
         | 
| 46 | 
            -
                  end
         | 
| 47 | 
            -
             | 
| 48 24 | 
             
                  def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
         | 
| 49 25 | 
             
                    if x_train.nil?
         | 
| 50 26 | 
             
                      dataset.refresh!
         | 
| @@ -55,16 +31,6 @@ module EasyML | |
| 55 31 | 
             
                    @is_fit = true
         | 
| 56 32 | 
             
                  end
         | 
| 57 33 |  | 
| 58 | 
            -
                  def decode_labels(ys, col: nil)
         | 
| 59 | 
            -
                    dataset.decode_labels(ys, col: col)
         | 
| 60 | 
            -
                  end
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                  def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
         | 
| 63 | 
            -
                    evaluator ||= self.evaluator
         | 
| 64 | 
            -
                    EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
         | 
| 65 | 
            -
                                                          evaluator: evaluator)
         | 
| 66 | 
            -
                  end
         | 
| 67 | 
            -
             | 
| 68 34 | 
             
                  def predict(xs)
         | 
| 69 35 | 
             
                    raise NotImplementedError, "Subclasses must implement predict method"
         | 
| 70 36 | 
             
                  end
         | 
| @@ -82,6 +48,16 @@ module EasyML | |
| 82 48 | 
             
                    save_model_file
         | 
| 83 49 | 
             
                  end
         | 
| 84 50 |  | 
| 51 | 
            +
                  def decode_labels(ys, col: nil)
         | 
| 52 | 
            +
                    dataset.decode_labels(ys, col: col)
         | 
| 53 | 
            +
                  end
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                  def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
         | 
| 56 | 
            +
                    evaluator ||= self.evaluator
         | 
| 57 | 
            +
                    EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
         | 
| 58 | 
            +
                                                          evaluator: evaluator)
         | 
| 59 | 
            +
                  end
         | 
| 60 | 
            +
             | 
| 85 61 | 
             
                  def save_model_file
         | 
| 86 62 | 
             
                    raise "No trained model! Need to train model before saving (call model.fit)" unless fit?
         | 
| 87 63 |  | 
| @@ -116,13 +92,13 @@ module EasyML | |
| 116 92 | 
             
                  end
         | 
| 117 93 |  | 
| 118 94 | 
             
                  def cleanup!
         | 
| 119 | 
            -
                    [ | 
| 95 | 
            +
                    [carrierwave_dir, model_dir].each do |dir|
         | 
| 120 96 | 
             
                      EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
         | 
| 121 97 | 
             
                    end
         | 
| 122 98 | 
             
                  end
         | 
| 123 99 |  | 
| 124 100 | 
             
                  def cleanup
         | 
| 125 | 
            -
                    [ | 
| 101 | 
            +
                    [carrierwave_dir, model_dir].each do |dir|
         | 
| 126 102 | 
             
                      EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
         | 
| 127 103 | 
             
                    end
         | 
| 128 104 | 
             
                  end
         | 
| @@ -133,7 +109,7 @@ module EasyML | |
| 133 109 |  | 
| 134 110 | 
             
                  private
         | 
| 135 111 |  | 
| 136 | 
            -
                  def  | 
| 112 | 
            +
                  def carrierwave_dir
         | 
| 137 113 | 
             
                    return unless file.path.present?
         | 
| 138 114 |  | 
| 139 115 | 
             
                    File.dirname(file.path).split("/")[0..-2].join("/")
         | 
| @@ -172,10 +148,34 @@ module EasyML | |
| 172 148 | 
             
                  end
         | 
| 173 149 |  | 
| 174 150 | 
             
                  def files_to_keep
         | 
| 175 | 
            -
                    Dir.glob(File.join( | 
| 151 | 
            +
                    Dir.glob(File.join(carrierwave_dir, "**/*")).select { |f| File.file?(f) }.sort_by do |filename|
         | 
| 176 152 | 
             
                      Time.parse(filename.split("/").last.gsub(/\D/, ""))
         | 
| 177 153 | 
             
                    end.reverse.take(5)
         | 
| 178 154 | 
             
                  end
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                  def dataset_is_a_dataset?
         | 
| 157 | 
            +
                    return if dataset.nil?
         | 
| 158 | 
            +
                    return if dataset.class.ancestors.include?(EasyML::Data::Dataset)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    errors.add(:dataset, "Must be a subclass of EasyML::Dataset")
         | 
| 161 | 
            +
                  end
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                  def validate_any_metrics?
         | 
| 164 | 
            +
                    return if metrics.any?
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
         | 
| 167 | 
            +
                  end
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                  def validate_metrics_for_task
         | 
| 170 | 
            +
                    nonsensical_metrics = metrics.select do |metric|
         | 
| 171 | 
            +
                      allowed_metrics.exclude?(metric)
         | 
| 172 | 
            +
                    end
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    return unless nonsensical_metrics.any?
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    errors.add(:metrics,
         | 
| 177 | 
            +
                               "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed_metrics.join(", ")}")
         | 
| 178 | 
            +
                  end
         | 
| 179 179 | 
             
                end
         | 
| 180 180 | 
             
              end
         | 
| 181 181 | 
             
            end
         | 
| @@ -155,7 +155,7 @@ module EasyML | |
| 155 155 | 
             
                      ys = ys.nil? ? nil : _preprocess(ys).flatten
         | 
| 156 156 | 
             
                      kwargs = { label: ys }.compact
         | 
| 157 157 | 
             
                      ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
         | 
| 158 | 
            -
                        dmat. | 
| 158 | 
            +
                        dmat.feature_names = column_names
         | 
| 159 159 | 
             
                      end
         | 
| 160 160 | 
             
                    end
         | 
| 161 161 |  | 
| @@ -1,4 +1,3 @@ | |
| 1 | 
            -
            # lib/railtie/generators/migration/migration_generator.rb
         | 
| 2 1 | 
             
            require "rails/generators"
         | 
| 3 2 | 
             
            require "rails/generators/active_record/migration"
         | 
| 4 3 |  | 
| @@ -13,12 +12,7 @@ module EasyML | |
| 13 12 | 
             
                    source_root File.expand_path("../../templates/migration", __dir__)
         | 
| 14 13 |  | 
| 15 14 | 
             
                    # Define the migration name
         | 
| 16 | 
            -
                    desc "Generates  | 
| 17 | 
            -
             | 
| 18 | 
            -
                    # Define the migration name; can be customized if needed
         | 
| 19 | 
            -
                    def self.migration_name
         | 
| 20 | 
            -
                      "create_easy_ml_models"
         | 
| 21 | 
            -
                    end
         | 
| 15 | 
            +
                    desc "Generates migrations for EasyMLModel, Dataset, and TunerRun"
         | 
| 22 16 |  | 
| 23 17 | 
             
                    # Specify the next migration number
         | 
| 24 18 | 
             
                    def self.next_migration_number(dirname)
         | 
| @@ -31,10 +25,24 @@ module EasyML | |
| 31 25 | 
             
                      end
         | 
| 32 26 | 
             
                    end
         | 
| 33 27 |  | 
| 34 | 
            -
                    # Generate the migration  | 
| 35 | 
            -
                    def  | 
| 36 | 
            -
                       | 
| 37 | 
            -
             | 
| 28 | 
            +
                    # Generate the migration files using the templates
         | 
| 29 | 
            +
                    def create_migration_files
         | 
| 30 | 
            +
                      create_easy_ml_models_migration
         | 
| 31 | 
            +
                    end
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    private
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    # Generate the migration file for EasyMLModel using the template
         | 
| 36 | 
            +
                    def create_easy_ml_models_migration
         | 
| 37 | 
            +
                      migration_template(
         | 
| 38 | 
            +
                        "create_easy_ml_models.rb.tt",
         | 
| 39 | 
            +
                        "db/migrate/create_easy_ml_models.rb"
         | 
| 40 | 
            +
                      )
         | 
| 41 | 
            +
                    end
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # Get the next migration number
         | 
| 44 | 
            +
                    def next_migration_number
         | 
| 45 | 
            +
                      self.class.next_migration_number(Rails.root.join("db/migrate"))
         | 
| 38 46 | 
             
                    end
         | 
| 39 47 | 
             
                  end
         | 
| 40 48 | 
             
                end
         | 
| @@ -16,7 +16,7 @@ module EasyML | |
| 16 16 | 
             
                  files_to_check.each do |file|
         | 
| 17 17 | 
             
                    FileUtils.chown_R(`whoami`.chomp, "staff", file)
         | 
| 18 18 | 
             
                    FileUtils.chmod_R(0o777, file)
         | 
| 19 | 
            -
                    File.delete(file)  | 
| 19 | 
            +
                    File.delete(file) if @files_to_keep.exclude?(file) && File.exist?(file)
         | 
| 20 20 | 
             
                  end
         | 
| 21 21 | 
             
                end
         | 
| 22 22 | 
             
              end
         | 
    
        data/lib/easy_ml/version.rb
    CHANGED
    
    
    
        data/lib/easy_ml.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | @@ -1,14 +1,14 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: easy_ml
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 0.1. | 
| 4 | 
            +
              version: 0.1.3
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - Brett Shollenberger
         | 
| 8 8 | 
             
            autorequire:
         | 
| 9 9 | 
             
            bindir: exe
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 | 
            -
            date: 2024-10- | 
| 11 | 
            +
            date: 2024-10-18 00:00:00.000000000 Z
         | 
| 12 12 | 
             
            dependencies:
         | 
| 13 13 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 14 14 | 
             
              name: activerecord
         | 
| @@ -58,20 +58,34 @@ dependencies: | |
| 58 58 | 
             
                - - "<"
         | 
| 59 59 | 
             
                  - !ruby/object:Gem::Version
         | 
| 60 60 | 
             
                    version: '4'
         | 
| 61 | 
            +
            - !ruby/object:Gem::Dependency
         | 
| 62 | 
            +
              name: fog
         | 
| 63 | 
            +
              requirement: !ruby/object:Gem::Requirement
         | 
| 64 | 
            +
                requirements:
         | 
| 65 | 
            +
                - - "~>"
         | 
| 66 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 67 | 
            +
                    version: '1.42'
         | 
| 68 | 
            +
              type: :runtime
         | 
| 69 | 
            +
              prerelease: false
         | 
| 70 | 
            +
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 71 | 
            +
                requirements:
         | 
| 72 | 
            +
                - - "~>"
         | 
| 73 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 74 | 
            +
                    version: '1.42'
         | 
| 61 75 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 62 76 | 
             
              name: fog-aws
         | 
| 63 77 | 
             
              requirement: !ruby/object:Gem::Requirement
         | 
| 64 78 | 
             
                requirements:
         | 
| 65 79 | 
             
                - - "~>"
         | 
| 66 80 | 
             
                  - !ruby/object:Gem::Version
         | 
| 67 | 
            -
                    version: ' | 
| 81 | 
            +
                    version: '2.0'
         | 
| 68 82 | 
             
              type: :runtime
         | 
| 69 83 | 
             
              prerelease: false
         | 
| 70 84 | 
             
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 71 85 | 
             
                requirements:
         | 
| 72 86 | 
             
                - - "~>"
         | 
| 73 87 | 
             
                  - !ruby/object:Gem::Version
         | 
| 74 | 
            -
                    version: ' | 
| 88 | 
            +
                    version: '2.0'
         | 
| 75 89 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 76 90 | 
             
              name: glue_gun_dsl
         | 
| 77 91 | 
             
              requirement: !ruby/object:Gem::Requirement
         | 
| @@ -156,6 +170,34 @@ dependencies: | |
| 156 170 | 
             
                - - ">="
         | 
| 157 171 | 
             
                  - !ruby/object:Gem::Version
         | 
| 158 172 | 
             
                    version: '0'
         | 
| 173 | 
            +
            - !ruby/object:Gem::Dependency
         | 
| 174 | 
            +
              name: wandb
         | 
| 175 | 
            +
              requirement: !ruby/object:Gem::Requirement
         | 
| 176 | 
            +
                requirements:
         | 
| 177 | 
            +
                - - "~>"
         | 
| 178 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 179 | 
            +
                    version: 0.1.6
         | 
| 180 | 
            +
              type: :runtime
         | 
| 181 | 
            +
              prerelease: false
         | 
| 182 | 
            +
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 183 | 
            +
                requirements:
         | 
| 184 | 
            +
                - - "~>"
         | 
| 185 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 186 | 
            +
                    version: 0.1.6
         | 
| 187 | 
            +
            - !ruby/object:Gem::Dependency
         | 
| 188 | 
            +
              name: xgb
         | 
| 189 | 
            +
              requirement: !ruby/object:Gem::Requirement
         | 
| 190 | 
            +
                requirements:
         | 
| 191 | 
            +
                - - ">="
         | 
| 192 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 193 | 
            +
                    version: '0'
         | 
| 194 | 
            +
              type: :runtime
         | 
| 195 | 
            +
              prerelease: false
         | 
| 196 | 
            +
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 197 | 
            +
                requirements:
         | 
| 198 | 
            +
                - - ">="
         | 
| 199 | 
            +
                  - !ruby/object:Gem::Version
         | 
| 200 | 
            +
                    version: '0'
         | 
| 159 201 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 160 202 | 
             
              name: annotate
         | 
| 161 203 | 
             
              requirement: !ruby/object:Gem::Requirement
         | 
| @@ -321,7 +363,6 @@ files: | |
| 321 363 | 
             
            - lib/easy_ml/support/git_ignorable.rb
         | 
| 322 364 | 
             
            - lib/easy_ml/support/synced_directory.rb
         | 
| 323 365 | 
             
            - lib/easy_ml/support/utc.rb
         | 
| 324 | 
            -
            - lib/easy_ml/trainer.rb
         | 
| 325 366 | 
             
            - lib/easy_ml/transforms.rb
         | 
| 326 367 | 
             
            - lib/easy_ml/version.rb
         | 
| 327 368 | 
             
            homepage: https://github.com/brettshollenberger/easy_ml
         | 
    
        data/lib/easy_ml/trainer.rb
    DELETED
    
    | @@ -1,92 +0,0 @@ | |
| 1 | 
            -
            module EasyML
         | 
| 2 | 
            -
              class Trainer
         | 
| 3 | 
            -
                # include GlueGun::DSL
         | 
| 4 | 
            -
                # include EasyML::Logging
         | 
| 5 | 
            -
             | 
| 6 | 
            -
                # define_attr :verbose, default: false
         | 
| 7 | 
            -
                # define_attr :root_dir do |root_dir|
         | 
| 8 | 
            -
                #   File.join(root_dir, "trainer")
         | 
| 9 | 
            -
                # end
         | 
| 10 | 
            -
             | 
| 11 | 
            -
                # define_config :dataset do |config|
         | 
| 12 | 
            -
                #   config.define_option :default do |option|
         | 
| 13 | 
            -
                #     option.set_class EasyML::Data::Dataset
         | 
| 14 | 
            -
                #     option.define_attr :root_dir
         | 
| 15 | 
            -
                #     option.define_attr :target
         | 
| 16 | 
            -
                #     option.define_attr :batch_size
         | 
| 17 | 
            -
                #   end
         | 
| 18 | 
            -
                # end
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                # define_config :model do |config|
         | 
| 21 | 
            -
                #   config.define_option :default do |option|
         | 
| 22 | 
            -
                #     option.set_class EasyML::Model
         | 
| 23 | 
            -
                #     option.define_attr :root_dir
         | 
| 24 | 
            -
                #     option.define_attr :name
         | 
| 25 | 
            -
                #     option.define_attr :hyperparameters
         | 
| 26 | 
            -
                #   end
         | 
| 27 | 
            -
                # end
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                # def train
         | 
| 30 | 
            -
                #   log_info("Starting training process") if verbose
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                #   dataset.refresh!
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                #   log_info("Fitting model") if verbose
         | 
| 35 | 
            -
                #   dataset.train(split_ys: true) do |xs, ys|
         | 
| 36 | 
            -
                #     model.fit(xs, ys)
         | 
| 37 | 
            -
                #   end
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                #   log_info("Saving model") if verbose
         | 
| 40 | 
            -
                #   model.save
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                #   log_info("Training completed") if verbose
         | 
| 43 | 
            -
                # end
         | 
| 44 | 
            -
             | 
| 45 | 
            -
                # def evaluate
         | 
| 46 | 
            -
                #   log_info("Starting evaluation process") if verbose
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                #   results = {}
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                #   %i[train test valid].each do |split|
         | 
| 51 | 
            -
                #     log_info("Evaluating on #{split} set") if verbose
         | 
| 52 | 
            -
                #     predictions = []
         | 
| 53 | 
            -
                #     actuals = []
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                #     dataset.send(split, split_ys: true) do |xs, ys|
         | 
| 56 | 
            -
                #       batch_predictions = model.predict(xs)
         | 
| 57 | 
            -
                #       predictions.concat(batch_predictions.to_a)
         | 
| 58 | 
            -
                #       actuals.concat(ys.to_a)
         | 
| 59 | 
            -
                #     end
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                #     results[split] = calculate_metrics(predictions, actuals)
         | 
| 62 | 
            -
                #   end
         | 
| 63 | 
            -
             | 
| 64 | 
            -
                #   log_info("Evaluation completed") if verbose
         | 
| 65 | 
            -
                #   results
         | 
| 66 | 
            -
                # end
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                # private
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                # def calculate_metrics(predictions, actuals)
         | 
| 71 | 
            -
                #   # Implement your metric calculations here
         | 
| 72 | 
            -
                #   # This is a placeholder and should be replaced with actual metric calculations
         | 
| 73 | 
            -
                #   {
         | 
| 74 | 
            -
                #     mse: mean_squared_error(predictions, actuals),
         | 
| 75 | 
            -
                #     mae: mean_absolute_error(predictions, actuals),
         | 
| 76 | 
            -
                #     r2: r_squared(predictions, actuals)
         | 
| 77 | 
            -
                #   }
         | 
| 78 | 
            -
                # end
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                # def mean_squared_error(predictions, actuals)
         | 
| 81 | 
            -
                #   # Implement MSE calculation
         | 
| 82 | 
            -
                # end
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                # def mean_absolute_error(predictions, actuals)
         | 
| 85 | 
            -
                #   # Implement MAE calculation
         | 
| 86 | 
            -
                # end
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                # def r_squared(predictions, actuals)
         | 
| 89 | 
            -
                #   # Implement R-squared calculation
         | 
| 90 | 
            -
                # end
         | 
| 91 | 
            -
              end
         | 
| 92 | 
            -
            end
         |