easy_ml 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/README.md +270 -0
- data/Rakefile +12 -0
- data/app/models/easy_ml/model.rb +59 -0
- data/app/models/easy_ml/models/xgboost.rb +9 -0
- data/app/models/easy_ml/models.rb +5 -0
- data/lib/easy_ml/core/model.rb +29 -0
- data/lib/easy_ml/core/model_core.rb +181 -0
- data/lib/easy_ml/core/model_evaluator.rb +137 -0
- data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
- data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
- data/lib/easy_ml/core/models/xgboost.rb +10 -0
- data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
- data/lib/easy_ml/core/models.rb +10 -0
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
- data/lib/easy_ml/core/tuner/adapters.rb +10 -0
- data/lib/easy_ml/core/tuner.rb +105 -0
- data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
- data/lib/easy_ml/core/uploaders.rb +7 -0
- data/lib/easy_ml/core.rb +9 -0
- data/lib/easy_ml/core_ext/pathname.rb +9 -0
- data/lib/easy_ml/core_ext.rb +5 -0
- data/lib/easy_ml/data/dataloader.rb +6 -0
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
- data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
- data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
- data/lib/easy_ml/data/dataset/splits.rb +11 -0
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
- data/lib/easy_ml/data/dataset/splitters.rb +9 -0
- data/lib/easy_ml/data/dataset.rb +430 -0
- data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
- data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
- data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
- data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
- data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
- data/lib/easy_ml/data/datasource.rb +33 -0
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
- data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
- data/lib/easy_ml/data/preprocessor.rb +238 -0
- data/lib/easy_ml/data/utils.rb +50 -0
- data/lib/easy_ml/data.rb +8 -0
- data/lib/easy_ml/deployment.rb +5 -0
- data/lib/easy_ml/engine.rb +26 -0
- data/lib/easy_ml/initializers/inflections.rb +4 -0
- data/lib/easy_ml/logging.rb +38 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
- data/lib/easy_ml/support/age.rb +27 -0
- data/lib/easy_ml/support/est.rb +1 -0
- data/lib/easy_ml/support/file_rotate.rb +23 -0
- data/lib/easy_ml/support/git_ignorable.rb +66 -0
- data/lib/easy_ml/support/synced_directory.rb +134 -0
- data/lib/easy_ml/support/utc.rb +1 -0
- data/lib/easy_ml/support.rb +10 -0
- data/lib/easy_ml/trainer.rb +92 -0
- data/lib/easy_ml/transforms.rb +29 -0
- data/lib/easy_ml/version.rb +5 -0
- data/lib/easy_ml.rb +23 -0
- metadata +353 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 7a959176791dac2307979438ad0f9a9319b4295fe25e489214df5f3b4c908466
|
4
|
+
data.tar.gz: c665ef3c19fda35197be653d9c14a34e9b2256d3c1b7387be4313086ba5d2c11
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: bda834230add0f3de2b57d8df4abde48f79f64275946b372d5069b49c610ab09c8cc1117a89a2d9206100b19b30c8f30bbfbe2d8973ce5db551626bfb612949d
|
7
|
+
data.tar.gz: 3d80ffc4930323f3b8cff51ef2e3475f557589e96d54717c750a63507c3c3ab682650c7ccb4daf2bd1168cb35627c0101579e8bd26b310e86906e134c6626249
|
data/README.md
ADDED
@@ -0,0 +1,270 @@
|
|
1
|
+
<img src="easy_ml.svg" alt="EasyML Logo" style="width: 310px; height: 300px;">
|
2
|
+
|
3
|
+
# EasyML
|
4
|
+
|
5
|
+
EasyML is a Ruby gem designed to simplify the process of building, deploying, and managing the lifecycle of machine learning models within a Ruby on Rails application. It is a plug-and-play, opinionated framework that currently supports XGBoost, with plans to expand support to a variety of models and infrastructures. EasyML aims to make deployment and lifecycle management straightforward and efficient.
|
6
|
+
|
7
|
+
## Features
|
8
|
+
|
9
|
+
- **Plug-and-Play Architecture**: EasyML is designed to be easily extendable, allowing for the integration of various machine learning models and data sources.
|
10
|
+
- **Opinionated Framework**: Provides a structured approach to model management, ensuring best practices are followed.
|
11
|
+
- **Model Lifecycle On Rails**: Seamlessly integrates with Ruby on Rails, allowing simplified deployment of models to production.
|
12
|
+
|
13
|
+
## Current and Planned Features
|
14
|
+
|
15
|
+
### Models Available
|
16
|
+
|
17
|
+
| XGBoost | LightGBM | TensorFlow | PyTorch |
|
18
|
+
| ------- | -------- | ---------- | ------- |
|
19
|
+
| ✅ | ❌ | ❌ | ❌ |
|
20
|
+
|
21
|
+
### Datasources Available
|
22
|
+
|
23
|
+
| S3 | File | Polars | SQL Databases | REST APIs |
|
24
|
+
| --- | ---- | ------ | ------------- | --------- |
|
25
|
+
| ✅ | ✅ | ✅ | ❌ | ❌ |
|
26
|
+
|
27
|
+
_Note: Features marked with ❌ are part of the roadmap and are not yet implemented._
|
28
|
+
|
29
|
+
## Quick Start
|
30
|
+
|
31
|
+
Building a production pipeline is as easy as 1, 2, 3!
|
32
|
+
|
33
|
+
### 1. Create Your Dataset
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
class MyDataset < EasyML::Data::Dataset
|
37
|
+
datasource :s3, s3_bucket: "my-bucket" # Every time the data changes, we'll pull new data
|
38
|
+
target "revenue" # What are we trying to predict?
|
39
|
+
splitter :date, date_column: "created_at" # How should we partition data into training, test, and validation datasets?
|
40
|
+
transforms DataPipeline # Class that manages data transformation, adding new columns, etc.
|
41
|
+
preprocessing_steps({
|
42
|
+
training: {
|
43
|
+
annual_revenue: { median: true, clip: { min: 0, max: 500_000 } }
|
44
|
+
}
|
45
|
+
}) # If annual revenue is missing, use the median value, after clipping the values into the approved list
|
46
|
+
end
|
47
|
+
```
|
48
|
+
|
49
|
+
### 2. Create a Model
|
50
|
+
|
51
|
+
```ruby
|
52
|
+
class MyModel < EasyML::Models::XGBoost
|
53
|
+
dataset MyDataset
|
54
|
+
task :regression # Or classification
|
55
|
+
hyperparameters({
|
56
|
+
max_depth: 5,
|
57
|
+
learning_rate: 0.1,
|
58
|
+
objective: "reg:squarederror"
|
59
|
+
})
|
60
|
+
end
|
61
|
+
```
|
62
|
+
|
63
|
+
### 3. Create a Trainer
|
64
|
+
|
65
|
+
```ruby
|
66
|
+
class MyTrainer < EasyML::Trainer
|
67
|
+
model MyModel
|
68
|
+
evaluator MyMetrics
|
69
|
+
end
|
70
|
+
|
71
|
+
class MyMetrics
|
72
|
+
def metric_we_make_money(y_pred, y_true)
|
73
|
+
return true if model_makes_money?
|
74
|
+
return false if model_lose_money?
|
75
|
+
end
|
76
|
+
|
77
|
+
def metric_sales_team_has_enough_leads(y_pred, y_true)
|
78
|
+
return false if sales_will_be_sitting_on_their_hands?
|
79
|
+
end
|
80
|
+
end
|
81
|
+
```
|
82
|
+
|
83
|
+
Now you're ready to predict in production!
|
84
|
+
|
85
|
+
```ruby
|
86
|
+
MyTrainer.train # Yay, we did it!
|
87
|
+
MyTrainer.deploy # Let the production hosts know it's live!
|
88
|
+
MyTrainer.predict(customer_data: "I am worth a lot of money")
|
89
|
+
# prediction: true!
|
90
|
+
```
|
91
|
+
|
92
|
+
## Data Management
|
93
|
+
|
94
|
+
EasyML provides a comprehensive data management system that handles all preprocessing tasks, including splitting data into train, test, and validation sets, and avoiding data leakage. The primary abstraction for data handling is the `Dataset` class, which ensures data is properly managed and prepared for machine learning tasks.
|
95
|
+
|
96
|
+
### Preprocessing Features
|
97
|
+
|
98
|
+
EasyML offers a variety of preprocessing features to prepare your data for machine learning models. Here are the main preprocessing steps, with examples of when to use each:
|
99
|
+
|
100
|
+
- **Mean Imputation**: Replace missing values with the mean of the feature. Use this when you want to maintain the average value of the data.
|
101
|
+
|
102
|
+
```ruby
|
103
|
+
annual_revenue: {
|
104
|
+
mean: true
|
105
|
+
}
|
106
|
+
```
|
107
|
+
|
108
|
+
- **Median Imputation**: Replace missing values with the median of the feature. This is useful when you want to maintain the central tendency of the data without being affected by outliers.
|
109
|
+
|
110
|
+
```ruby
|
111
|
+
annual_revenue: {
|
112
|
+
median: true
|
113
|
+
}
|
114
|
+
```
|
115
|
+
|
116
|
+
- **Forward Fill (ffill)**: Fill missing values with the last observed value. Use this for time series data where the last known value is a reasonable estimate for missing values.
|
117
|
+
|
118
|
+
```ruby
|
119
|
+
created_date: {
|
120
|
+
ffill: true
|
121
|
+
}
|
122
|
+
```
|
123
|
+
|
124
|
+
- **Most Frequent Imputation**: Replace missing values with the most frequently occurring value. This is useful for categorical data where the mode is a reasonable estimate for missing values.
|
125
|
+
|
126
|
+
```ruby
|
127
|
+
loan_purpose: {
|
128
|
+
most_frequent: true
|
129
|
+
}
|
130
|
+
```
|
131
|
+
|
132
|
+
- **Constant Imputation**: Replace missing values with a constant value. Use this when you have a specific value that should be used for missing data.
|
133
|
+
|
134
|
+
```ruby
|
135
|
+
loan_purpose: {
|
136
|
+
constant: { fill_value: 'unknown' }
|
137
|
+
}
|
138
|
+
```
|
139
|
+
|
140
|
+
- **Today Imputation**: Fill missing date values with the current date. Use this for features that should default to the current date.
|
141
|
+
|
142
|
+
```ruby
|
143
|
+
created_date: {
|
144
|
+
today: true
|
145
|
+
}
|
146
|
+
```
|
147
|
+
|
148
|
+
- **One-Hot Encoding**: Convert categorical variables into a set of binary variables. Use this when you have categorical data that needs to be converted into a numerical format for model training.
|
149
|
+
|
150
|
+
```ruby
|
151
|
+
loan_purpose: {
|
152
|
+
one_hot: true
|
153
|
+
}
|
154
|
+
```
|
155
|
+
|
156
|
+
- **Label Encoding**: Convert categorical variables into integer labels. Use this when you have categorical data that can be ordinally encoded.
|
157
|
+
|
158
|
+
```ruby
|
159
|
+
loan_purpose: {
|
160
|
+
categorical: {
|
161
|
+
encode_labels: true
|
162
|
+
}
|
163
|
+
}
|
164
|
+
```
|
165
|
+
|
166
|
+
### Other Dataset Features
|
167
|
+
|
168
|
+
- **Data Splitting**: Automatically split data into train, test, and validation sets using various strategies, such as date-based splitting.
|
169
|
+
- **Data Synchronization**: Ensure data is synced from its source, such as S3 or local files.
|
170
|
+
- **Batch Processing**: Process data in batches to handle large datasets efficiently.
|
171
|
+
- **Null Handling**: Detect, alert on, and handle null values in datasets to ensure data quality.
|
172
|
+
|
173
|
+
## Installation
|
174
|
+
|
175
|
+
Install the required Python dependencies first, then set up the gem:
|
176
|
+
|
177
|
+
**Prerequisite: install the Python dependencies** (don't worry — all of EasyML's code is Ruby; it just calls through to Python):
|
178
|
+
|
179
|
+
```bash
|
180
|
+
pip install wandb
|
181
|
+
pip install optuna
|
182
|
+
```
|
183
|
+
|
184
|
+
1. **Install the gem**:
|
185
|
+
|
186
|
+
```bash
|
187
|
+
gem install easy_ml
|
188
|
+
```
|
189
|
+
|
190
|
+
2. **Run the generator to store model versions**:
|
191
|
+
|
192
|
+
```bash
|
193
|
+
rails generate easy_ml:migration
|
194
|
+
rails db:migrate
|
195
|
+
```
|
196
|
+
|
197
|
+
3. **Configure CarrierWave for S3 storage**:
|
198
|
+
|
199
|
+
Ensure you have CarrierWave configured to use AWS S3. If not, add the following configuration:
|
200
|
+
|
201
|
+
```ruby
|
202
|
+
# config/initializers/carrierwave.rb
|
203
|
+
CarrierWave.configure do |config|
|
204
|
+
config.fog_provider = 'fog/aws'
|
205
|
+
config.fog_credentials = {
|
206
|
+
provider: 'AWS',
|
207
|
+
aws_access_key_id: ENV['AWS_ACCESS_KEY_ID'],
|
208
|
+
aws_secret_access_key: ENV['AWS_SECRET_ACCESS_KEY'],
|
209
|
+
region: ENV['AWS_REGION'],
|
210
|
+
}
|
211
|
+
config.fog_directory = ENV['AWS_S3_BUCKET']
|
212
|
+
config.fog_public = false
|
213
|
+
config.storage = :fog
|
214
|
+
end
|
215
|
+
```
|
216
|
+
|
217
|
+
## Usage
|
218
|
+
|
219
|
+
To use EasyML in your Rails application, follow these steps:
|
220
|
+
|
221
|
+
1. **Define your preprocessing steps** in a configuration hash. For example:
|
222
|
+
|
223
|
+
```ruby
|
224
|
+
preprocessing_steps = {
|
225
|
+
training: {
|
226
|
+
annual_revenue: {
|
227
|
+
median: true,
|
228
|
+
clip: { min: 0, max: 1_000_000 }
|
229
|
+
},
|
230
|
+
loan_purpose: {
|
231
|
+
categorical: {
|
232
|
+
categorical_min: 2,
|
233
|
+
one_hot: true
|
234
|
+
}
|
235
|
+
}
|
236
|
+
}
|
237
|
+
}
|
238
|
+
```
|
239
|
+
|
240
|
+
2. **Create a dataset** using the `EasyML::Data::Dataset` class, providing the necessary configuration, such as the data source, target, and preprocessing steps.
|
241
|
+
|
242
|
+
3. **Train a model** using the `EasyML::Models` module, specifying the model class and configuration.
|
243
|
+
|
244
|
+
4. **Deploy the model** by marking it as live and storing it in the configured S3 bucket.
|
245
|
+
|
246
|
+
## Development
|
247
|
+
|
248
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
249
|
+
|
250
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
251
|
+
|
252
|
+
## Contributing
|
253
|
+
|
254
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/easy_ml. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [code of conduct](https://github.com/[USERNAME]/easy_ml/blob/main/CODE_OF_CONDUCT.md).
|
255
|
+
|
256
|
+
## License
|
257
|
+
|
258
|
+
The gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
|
259
|
+
|
260
|
+
## Code of Conduct
|
261
|
+
|
262
|
+
Everyone interacting in the EasyML project's codebases, issue trackers, chat rooms, and mailing lists is expected to follow the [code of conduct](https://github.com/[USERNAME]/easy_ml/blob/main/CODE_OF_CONDUCT.md).
|
263
|
+
|
264
|
+
## Expected Future Enhancements
|
265
|
+
|
266
|
+
- **Support for Additional Models**: Integration with LightGBM, TensorFlow, and PyTorch.
|
267
|
+
- **Expanded Data Source Support**: Ability to pull data from SQL databases and REST APIs.
|
268
|
+
- **Enhanced Deployment Options**: More flexible deployment strategies and integration with CI/CD pipelines.
|
269
|
+
- **Advanced Monitoring and Logging**: Improved tools for monitoring model performance and logging.
|
270
|
+
- **User Interface Improvements**: Enhanced UI components for managing models and datasets.
|
data/Rakefile
ADDED
@@ -0,0 +1,59 @@
|
|
1
|
+
require_relative "../../../lib/easy_ml/core/model"
|
2
|
+
module EasyML
  # ActiveRecord-backed model version stored in the `easy_ml_models` table.
  # Rows that share a `name` are versions of one model, and at most one of them
  # should have `is_live` set. Validation, fitting and model-file storage come
  # from EasyML::Core::ModelCore.
  class Model < ActiveRecord::Base
    include EasyML::Core::ModelCore

    self.table_name = "easy_ml_models"

    scope :live, -> { where(is_live: true) }
    attribute :root_dir, :string
    after_initialize :apply_defaults

    validate :only_one_model_is_live?

    # Validation enforcing "one live version per name".
    #
    # Raises when the database already holds more than one live row for this
    # name, which is corrupt state that has to be repaired by hand. Otherwise it
    # adds an error when this record is live while a *different* version is
    # live. Use #mark_live to switch the live version.
    def only_one_model_is_live?
      # NOTE(review): nothing in this file assigns @marking_live.
      return if @marking_live

      if previous_versions.live.count > 1
        raise "Multiple previous versions of #{name} are live! This should never happen. Update previous versions to is_live=false before proceeding"
      end

      return unless is_live

      # Exclude this row. previous_versions includes the record itself once it
      # is persisted, so re-saving the current live version was rejected as
      # conflicting with itself. For a new record (id nil) the filter becomes
      # "id IS NOT NULL" and keeps every persisted row.
      return unless previous_versions.where.not(id: id).live.exists?

      errors.add(:is_live,
                 "cannot mark model live when previous version is live. Explicitly use the mark_live method to mark this as the live version")
    end

    # Makes this record the live version for its name. Inside one transaction it
    # clears is_live on every other row with the same name, then sets it on
    # this row. update_all bypasses validations and callbacks. It also leaves
    # the in-memory is_live attribute untouched (reload to observe the change).
    def mark_live
      transaction do
        self.class.where(name: name).where.not(id: id).update_all(is_live: false)
        self.class.where(id: id).update_all(is_live: true)
      end
    end

    # All rows sharing this record's name, newest id first. This includes the
    # record itself once it has been persisted.
    def previous_versions
      EasyML::Model.where(name: name).order(id: :desc)
    end

    private

    # Paths of stored model files that ModelCore#cleanup must not delete:
    # - this record;
    # - for every live model, its most recent non-live versions (5, or 4 when
    #   the live model shares this record's name);
    # - for names with no live version, up to 5 versions created in the last
    #   two days;
    # - every live model.
    def files_to_keep
      live_models = self.class.live

      recent_copies = live_models.flat_map do |live|
        # Fetch the most recent non-live versions with the live model's name.
        self.class.where(name: live.name).where(is_live: false).order(created_at: :desc).limit(live.name == name ? 4 : 5)
      end

      recent_versions = self.class
                        .where.not(
                          "EXISTS (SELECT 1 FROM easy_ml_models e2 WHERE e2.name = easy_ml_models.name AND e2.is_live = true)"
                        )
                        .where("created_at >= ?", 2.days.ago)
                        .order(created_at: :desc)
                        .group_by(&:name)
                        .flat_map { |_, models| models.take(5) }

      # Records without a stored file report a nil path. Drop those so the
      # keep-list only holds real paths.
      ([self] + recent_versions + recent_copies + live_models)
        .compact
        .map { |model| model.file.path }
        .compact
        .uniq
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
require "carrierwave"
|
2
|
+
require_relative "model_core"
|
3
|
+
require_relative "uploaders/model_uploader"
|
4
|
+
|
5
|
+
module EasyML
  module Core
    # Plain-Ruby (non-ActiveRecord) model. Attributes are declared through
    # GlueGun::DSL, and shared model behaviour comes from
    # EasyML::Core::ModelCore.
    class Model
      include GlueGun::DSL

      attribute :name, :string
      attribute :version, :string
      attribute :task, :string, default: "regression"
      attribute :metrics, :array
      attribute :ml_model, :string
      attribute :file, :string
      attribute :root_dir, :string
      attribute :objective
      attribute :evaluator
      attribute :evaluator_metric

      # NOTE(review): keep this include after the attribute declarations.
      # ModelCore.included calls `mount_uploader :file`, which presumably must
      # run after `attribute :file` so the uploader accessors are not replaced
      # by the plain string attribute. Confirm before reordering.
      include EasyML::Core::ModelCore

      # Builds the model from an options hash (forwarded unchanged to the
      # superclass initializer), then fills in version, metrics and ml_model
      # defaults via apply_defaults.
      def initialize(options = {})
        super
        apply_defaults
      end
    end
  end
end
|
@@ -0,0 +1,181 @@
|
|
1
|
+
require "carrierwave"
|
2
|
+
require_relative "uploaders/model_uploader"
|
3
|
+
|
4
|
+
module EasyML
  module Core
    # Behaviour shared by EasyML::Model (ActiveRecord-backed) and
    # EasyML::Core::Model (plain Ruby). It covers task and metric validation,
    # fitting, evaluation, and storing the trained model file through
    # CarrierWave.
    #
    # The including class supplies the attributes read here (task, metrics,
    # name, version, ml_model, root_dir, evaluator). It also implements
    # #predict, #load and a private #_save_model_file(path).
    module ModelCore
      attr_accessor :dataset

      # Wires the including class up with GlueGun's DSL and a CarrierWave
      # uploader mounted on :file. It also registers the validations below and
      # a callback that stores the model file before validation once the model
      # has been fit.
      def self.included(base)
        base.send(:include, GlueGun::DSL)
        base.send(:extend, CarrierWave::Mount)
        base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)

        base.class_eval do
          validates :task, inclusion: { in: %w[regression classification] }
          validates :task, presence: true
          validate :dataset_is_a_dataset?
          validate :validate_any_metrics?
          validate :validate_metrics_for_task
          before_validation :save_model_file, if: -> { fit? }
        end
      end

      # Validation: when a dataset is assigned, it must be an EasyML::Data::Dataset.
      def dataset_is_a_dataset?
        return if dataset.nil?
        return if dataset.class.ancestors.include?(EasyML::Data::Dataset)

        # The message names the class the check actually uses. It previously
        # referred to a non-existent EasyML::Dataset.
        errors.add(:dataset, "Must be a subclass of EasyML::Data::Dataset")
      end

      # Validation: at least one metric must be configured. A nil metrics value
      # counts as empty instead of raising NoMethodError.
      def validate_any_metrics?
        return if Array(metrics).any?

        errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
      end

      # Validation: every configured metric must be allowed for the current
      # task (see #allowed_metrics). A nil metrics value counts as empty.
      def validate_metrics_for_task
        allowed = allowed_metrics
        nonsensical_metrics = Array(metrics).reject { |metric| allowed.include?(metric) }
        return if nonsensical_metrics.empty?

        errors.add(:metrics,
                   "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed.join(", ")}")
      end

      # Trains the model and marks it as fit.
      #
      # Without x_train, it refreshes the dataset and calls #train_in_batches.
      # Otherwise it calls #train with the given arrays. Neither method is
      # defined in this module; presumably the concrete model class supplies
      # them.
      def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
        if x_train.nil?
          dataset.refresh!
          train_in_batches
        else
          train(x_train, y_train, x_valid, y_valid)
        end
        @is_fit = true
      end

      # Delegates label decoding to the dataset.
      def decode_labels(ys, col: nil)
        dataset.decode_labels(ys, col: col)
      end

      # Scores predictions with EasyML::Core::ModelEvaluator. It falls back to
      # this model's configured evaluator when none is passed.
      def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
        evaluator ||= self.evaluator
        EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
                                              evaluator: evaluator)
      end

      # @raise [NotImplementedError] concrete model classes must override this.
      def predict(xs)
        raise NotImplementedError, "Subclasses must implement predict method"
      end

      # @raise [NotImplementedError] concrete model classes must override this.
      def load
        raise NotImplementedError, "Subclasses must implement load method"
      end

      # Runs the superclass `save` when one exists (e.g. ActiveRecord), then
      # stores the trained model file.
      #
      # NOTE(review): when super is ActiveRecord's save, the before_validation
      # callback registered in .included already calls #save_model_file for a
      # fit model, so the file is written and stored twice. Two other
      # consequences follow:
      # - This method returns #save_model_file's result (the directory array
      #   from #cleanup) rather than super's true/false. A failed validation
      #   therefore still reads as truthy, and the file is stored anyway.
      # - An unfit model is persisted by super before #save_model_file raises.
      # These are left unchanged because callers' reliance on them is not
      # visible here. Confirm before fixing.
      def save
        super if defined?(super) && self.class.superclass.method_defined?(:save)
        save_model_file
      end

      # Serialises the fitted model to <model_dir>/<version>.json via
      # #_save_model_file, then assigns that file to the mounted uploader and
      # stores it. Finally it prunes old files with #cleanup.
      #
      # @raise [RuntimeError] if the model has not been fit.
      def save_model_file
        raise "No trained model! Need to train model before saving (call model.fit)" unless fit?

        path = File.join(model_dir, "#{version}.json")
        ensure_directory_exists(File.dirname(path))

        _save_model_file(path)

        File.open(path) do |f|
          self.file = f
        end
        file.store!

        cleanup
      end

      # Hyperparameters as a Hash. @hyperparameters is not assigned in this
      # module; presumably the including class sets it.
      def get_params
        @hyperparameters.to_h
      end

      # Metric names permitted for the current task. Returns [] when task is
      # blank or unrecognised.
      def allowed_metrics
        return [] unless task.present?

        case task.to_sym
        when :regression
          %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
        when :classification
          %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
        else
          []
        end
      end

      # Runs EasyML::FileRotate over the upload and local model directories
      # with an empty keep-list. A directory that cannot be determined (no
      # stored file yet) is skipped.
      def cleanup!
        allowlist = extension_allowlist
        [file_dir, model_dir].compact.each do |dir|
          EasyML::FileRotate.new(dir, []).cleanup(allowlist)
        end
      end

      # Like #cleanup!, but passes #files_to_keep as the keep-list. The
      # keep-list and extension allowlist are computed once, not per directory;
      # the ActiveRecord override of files_to_keep issues several queries.
      def cleanup
        keep = files_to_keep
        allowlist = extension_allowlist
        [file_dir, model_dir].compact.each do |dir|
          EasyML::FileRotate.new(dir, keep).cleanup(allowlist)
        end
      end

      # True once #fit has completed on this instance.
      def fit?
        @is_fit == true
      end

      private

      # Parent of the directory holding the stored model file, or nil when no
      # file is stored.
      def file_dir
        return unless file.path.present?

        File.dirname(file.path).split("/")[0..-2].join("/")
      end

      # File extensions accepted by the model uploader.
      def extension_allowlist
        EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
      end

      # Hook that writes the trained model to `path`. Concrete model classes
      # override it.
      def _save_model_file(path = nil)
        raise NotImplementedError, "Subclasses must implement _save_model_file method"
      end

      def ensure_directory_exists(dir)
        FileUtils.mkdir_p(dir) unless File.directory?(dir)
      end

      # Fills in version, metrics and ml_model when they are unset.
      def apply_defaults
        self.version ||= generate_version_string
        self.metrics ||= allowed_metrics
        self.ml_model ||= get_ml_model
      end

      # Underscored, unqualified class name, e.g. "xg_boost" for ...::XGBoost.
      def get_ml_model
        self.class.name.split("::").last.underscore
      end

      # "<underscored class name>_<UTC timestamp YYYYmmddHHMMSS>".
      def generate_version_string
        timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
        model_name = self.class.name.split("::").last.underscore
        "#{model_name}_#{timestamp}"
      end

      # Local directory for model files:
      # <root_dir>/easy_ml_models/<name with whitespace removed, underscored>.
      def model_dir
        File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
      end

      # The 5 newest plain files directly under #file_dir. They are ordered by
      # parsing every digit in the basename as a time.
      # NOTE(review): Time.parse raises ArgumentError for a basename without
      # digits, and File.join raises TypeError when file_dir is nil.
      def files_to_keep
        Dir.glob(File.join(file_dir, "*")).select { |f| File.file?(f) }.sort_by do |filename|
          Time.parse(filename.split("/").last.gsub(/\D/, ""))
        end.reverse.take(5)
      end
    end
  end
end
|