easy_ml 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +270 -0
- data/Rakefile +12 -0
- data/app/models/easy_ml/model.rb +59 -0
- data/app/models/easy_ml/models/xgboost.rb +9 -0
- data/app/models/easy_ml/models.rb +5 -0
- data/lib/easy_ml/core/model.rb +29 -0
- data/lib/easy_ml/core/model_core.rb +181 -0
- data/lib/easy_ml/core/model_evaluator.rb +137 -0
- data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
- data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
- data/lib/easy_ml/core/models/xgboost.rb +10 -0
- data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
- data/lib/easy_ml/core/models.rb +10 -0
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
- data/lib/easy_ml/core/tuner/adapters.rb +10 -0
- data/lib/easy_ml/core/tuner.rb +105 -0
- data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
- data/lib/easy_ml/core/uploaders.rb +7 -0
- data/lib/easy_ml/core.rb +9 -0
- data/lib/easy_ml/core_ext/pathname.rb +9 -0
- data/lib/easy_ml/core_ext.rb +5 -0
- data/lib/easy_ml/data/dataloader.rb +6 -0
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
- data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
- data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
- data/lib/easy_ml/data/dataset/splits.rb +11 -0
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
- data/lib/easy_ml/data/dataset/splitters.rb +9 -0
- data/lib/easy_ml/data/dataset.rb +430 -0
- data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
- data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
- data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
- data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
- data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
- data/lib/easy_ml/data/datasource.rb +33 -0
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
- data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
- data/lib/easy_ml/data/preprocessor.rb +238 -0
- data/lib/easy_ml/data/utils.rb +50 -0
- data/lib/easy_ml/data.rb +8 -0
- data/lib/easy_ml/deployment.rb +5 -0
- data/lib/easy_ml/engine.rb +26 -0
- data/lib/easy_ml/initializers/inflections.rb +4 -0
- data/lib/easy_ml/logging.rb +38 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
- data/lib/easy_ml/support/age.rb +27 -0
- data/lib/easy_ml/support/est.rb +1 -0
- data/lib/easy_ml/support/file_rotate.rb +23 -0
- data/lib/easy_ml/support/git_ignorable.rb +66 -0
- data/lib/easy_ml/support/synced_directory.rb +134 -0
- data/lib/easy_ml/support/utc.rb +1 -0
- data/lib/easy_ml/support.rb +10 -0
- data/lib/easy_ml/trainer.rb +92 -0
- data/lib/easy_ml/transforms.rb +29 -0
- data/lib/easy_ml/version.rb +5 -0
- data/lib/easy_ml.rb +23 -0
- metadata +353 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 7a959176791dac2307979438ad0f9a9319b4295fe25e489214df5f3b4c908466
|
4
|
+
data.tar.gz: c665ef3c19fda35197be653d9c14a34e9b2256d3c1b7387be4313086ba5d2c11
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: bda834230add0f3de2b57d8df4abde48f79f64275946b372d5069b49c610ab09c8cc1117a89a2d9206100b19b30c8f30bbfbe2d8973ce5db551626bfb612949d
|
7
|
+
data.tar.gz: 3d80ffc4930323f3b8cff51ef2e3475f557589e96d54717c750a63507c3c3ab682650c7ccb4daf2bd1168cb35627c0101579e8bd26b310e86906e134c6626249
|
data/README.md
ADDED
@@ -0,0 +1,270 @@
|
|
1
|
+
<img src="easy_ml.svg" alt="EasyML Logo" style="width: 310px; height: 300px;">
|
2
|
+
|
3
|
+
# EasyML
|
4
|
+
|
5
|
+
EasyML is a Ruby gem designed to simplify the process of building, deploying, and managing the lifecycle of machine learning models within a Ruby on Rails application. It is a plug-and-play, opinionated framework that currently supports XGBoost, with plans to expand support to a variety of models and infrastructures. EasyML aims to make deployment and lifecycle management straightforward and efficient.
|
6
|
+
|
7
|
+
## Features
|
8
|
+
|
9
|
+
- **Plug-and-Play Architecture**: EasyML is designed to be easily extendable, allowing for the integration of various machine learning models and data sources.
|
10
|
+
- **Opinionated Framework**: Provides a structured approach to model management, ensuring best practices are followed.
|
11
|
+
- **Model Lifecycle On Rails**: Seamlessly integrates with Ruby on Rails, allowing simplified deployment of models to production.
|
12
|
+
|
13
|
+
## Current and Planned Features
|
14
|
+
|
15
|
+
### Models Available
|
16
|
+
|
17
|
+
| XGBoost | LightGBM | TensorFlow | PyTorch |
|
18
|
+
| ------- | -------- | ---------- | ------- |
|
19
|
+
| ✅ | ❌ | ❌ | ❌ |
|
20
|
+
|
21
|
+
### Datasources Available
|
22
|
+
|
23
|
+
| S3 | File | Polars | SQL Databases | REST APIs |
|
24
|
+
| --- | ---- | ------ | ------------- | --------- |
|
25
|
+
| ✅ | ✅ | ✅ | ❌ | ❌ |
|
26
|
+
|
27
|
+
_Note: Features marked with ❌ are part of the roadmap and are not yet implemented._
|
28
|
+
|
29
|
+
## Quick Start
|
30
|
+
|
31
|
+
Building a production pipeline is as easy as 1, 2, 3!
|
32
|
+
|
33
|
+
### 1. Create Your Dataset
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
class MyDataset < EasyML::Data::Dataset
|
37
|
+
datasource :s3, s3_bucket: "my-bucket" # Every time the data changes, we'll pull new data
|
38
|
+
target "revenue" # What are we trying to predict?
|
39
|
+
splitter :date, date_column: "created_at" # How should we partition data into training, test, and validation datasets?
|
40
|
+
transforms DataPipeline # Class that manages data transformation, adding new columns, etc.
|
41
|
+
preprocessing_steps({
|
42
|
+
training: {
|
43
|
+
annual_revenue: { median: true, clip: { min: 0, max: 500_000 } }
|
44
|
+
}
|
45
|
+
}) # If annual revenue is missing, fill it with the median value (computed after clipping values into the allowed range)
|
46
|
+
end
|
47
|
+
```
|
48
|
+
|
49
|
+
### 2. Create a Model
|
50
|
+
|
51
|
+
```ruby
|
52
|
+
class MyModel < EasyML::Models::XGBoost
|
53
|
+
dataset MyDataset
|
54
|
+
task :regression # Or classification
|
55
|
+
hyperparameters({
|
56
|
+
max_depth: 5,
|
57
|
+
learning_rate: 0.1,
|
58
|
+
objective: "reg:squarederror"
|
59
|
+
})
|
60
|
+
end
|
61
|
+
```
|
62
|
+
|
63
|
+
### 3. Create a Trainer
|
64
|
+
|
65
|
+
```ruby
|
66
|
+
class MyTrainer < EasyML::Trainer
|
67
|
+
model MyModel
|
68
|
+
evaluator MyMetrics
|
69
|
+
end
|
70
|
+
|
71
|
+
class MyMetrics
|
72
|
+
def metric_we_make_money(y_pred, y_true)
|
73
|
+
return true if model_makes_money?
|
74
|
+
return false if model_lose_money?
|
75
|
+
end
|
76
|
+
|
77
|
+
def metric_sales_team_has_enough_leads(y_pred, y_true)
|
78
|
+
return false if sales_will_be_sitting_on_their_hands?
|
79
|
+
end
|
80
|
+
end
|
81
|
+
```
|
82
|
+
|
83
|
+
Now you're ready to predict in production!
|
84
|
+
|
85
|
+
```ruby
|
86
|
+
MyTrainer.train # Yay, we did it!
|
87
|
+
MyTrainer.deploy # Let the production hosts know it's live!
|
88
|
+
MyTrainer.predict(customer_data: "I am worth a lot of money")
|
89
|
+
# prediction: true!
|
90
|
+
```
|
91
|
+
|
92
|
+
## Data Management
|
93
|
+
|
94
|
+
EasyML provides a comprehensive data management system that handles all preprocessing tasks, including splitting data into train, test, and validation sets, and avoiding data leakage. The primary abstraction for data handling is the `Dataset` class, which ensures data is properly managed and prepared for machine learning tasks.
|
95
|
+
|
96
|
+
### Preprocessing Features
|
97
|
+
|
98
|
+
EasyML offers a variety of preprocessing features to prepare your data for machine learning models. Here are the available preprocessing steps, with examples of when to use each:
|
99
|
+
|
100
|
+
- **Mean Imputation**: Replace missing values with the mean of the feature. Use this when you want to maintain the average value of the data.
|
101
|
+
|
102
|
+
```ruby
|
103
|
+
annual_revenue: {
|
104
|
+
mean: true
|
105
|
+
}
|
106
|
+
```
|
107
|
+
|
108
|
+
- **Median Imputation**: Replace missing values with the median of the feature. This is useful when you want to maintain the central tendency of the data without being affected by outliers.
|
109
|
+
|
110
|
+
```ruby
|
111
|
+
annual_revenue: {
|
112
|
+
median: true
|
113
|
+
}
|
114
|
+
```
|
115
|
+
|
116
|
+
- **Forward Fill (ffill)**: Fill missing values with the last observed value. Use this for time series data where the last known value is a reasonable estimate for missing values.
|
117
|
+
|
118
|
+
```ruby
|
119
|
+
created_date: {
|
120
|
+
ffill: true
|
121
|
+
}
|
122
|
+
```
|
123
|
+
|
124
|
+
- **Most Frequent Imputation**: Replace missing values with the most frequently occurring value. This is useful for categorical data where the mode is a reasonable estimate for missing values.
|
125
|
+
|
126
|
+
```ruby
|
127
|
+
loan_purpose: {
|
128
|
+
most_frequent: true
|
129
|
+
}
|
130
|
+
```
|
131
|
+
|
132
|
+
- **Constant Imputation**: Replace missing values with a constant value. Use this when you have a specific value that should be used for missing data.
|
133
|
+
|
134
|
+
```ruby
|
135
|
+
loan_purpose: {
|
136
|
+
constant: { fill_value: 'unknown' }
|
137
|
+
}
|
138
|
+
```
|
139
|
+
|
140
|
+
- **Today Imputation**: Fill missing date values with the current date. Use this for features that should default to the current date.
|
141
|
+
|
142
|
+
```ruby
|
143
|
+
created_date: {
|
144
|
+
today: true
|
145
|
+
}
|
146
|
+
```
|
147
|
+
|
148
|
+
- **One-Hot Encoding**: Convert categorical variables into a set of binary variables. Use this when you have categorical data that needs to be converted into a numerical format for model training.
|
149
|
+
|
150
|
+
```ruby
|
151
|
+
loan_purpose: {
|
152
|
+
one_hot: true
|
153
|
+
}
|
154
|
+
```
|
155
|
+
|
156
|
+
- **Label Encoding**: Convert categorical variables into integer labels. Use this when you have categorical data that can be ordinally encoded.
|
157
|
+
|
158
|
+
```ruby
|
159
|
+
loan_purpose: {
|
160
|
+
categorical: {
|
161
|
+
encode_labels: true
|
162
|
+
}
|
163
|
+
}
|
164
|
+
```
|
165
|
+
|
166
|
+
### Other Dataset Features
|
167
|
+
|
168
|
+
- **Data Splitting**: Automatically split data into train, test, and validation sets using various strategies, such as date-based splitting.
|
169
|
+
- **Data Synchronization**: Ensure data is synced from its source, such as S3 or local files.
|
170
|
+
- **Batch Processing**: Process data in batches to handle large datasets efficiently.
|
171
|
+
- **Null Handling**: Alert and handle null values in datasets to ensure data quality.
|
172
|
+
|
173
|
+
## Installation
|
174
|
+
|
175
|
+
Follow the steps below to install EasyML and its Python dependencies.
|
176
|
+
|
177
|
+
**Prerequisite: install the Python dependencies** (don't worry: all code is in Ruby; EasyML just calls through to Python):
|
178
|
+
|
179
|
+
```bash
|
180
|
+
pip install wandb
|
181
|
+
pip install optuna
|
182
|
+
```
|
183
|
+
|
184
|
+
1. **Install the gem**:
|
185
|
+
|
186
|
+
```bash
|
187
|
+
gem install easy_ml
|
188
|
+
```
|
189
|
+
|
190
|
+
2. **Run the generator to store model versions**:
|
191
|
+
|
192
|
+
```bash
|
193
|
+
rails generate easy_ml:migration
|
194
|
+
rails db:migrate
|
195
|
+
```
|
196
|
+
|
197
|
+
3. **Configure CarrierWave for S3 storage**:
|
198
|
+
|
199
|
+
Ensure you have CarrierWave configured to use AWS S3. If not, add the following configuration:
|
200
|
+
|
201
|
+
```ruby
|
202
|
+
# config/initializers/carrierwave.rb
|
203
|
+
CarrierWave.configure do |config|
|
204
|
+
config.fog_provider = 'fog/aws'
|
205
|
+
config.fog_credentials = {
|
206
|
+
provider: 'AWS',
|
207
|
+
aws_access_key_id: ENV['AWS_ACCESS_KEY_ID'],
|
208
|
+
aws_secret_access_key: ENV['AWS_SECRET_ACCESS_KEY'],
|
209
|
+
region: ENV['AWS_REGION'],
|
210
|
+
}
|
211
|
+
config.fog_directory = ENV['AWS_S3_BUCKET']
|
212
|
+
config.fog_public = false
|
213
|
+
config.storage = :fog
|
214
|
+
end
|
215
|
+
```
|
216
|
+
|
217
|
+
## Usage
|
218
|
+
|
219
|
+
To use EasyML in your Rails application, follow these steps:
|
220
|
+
|
221
|
+
1. **Define your preprocessing steps** in a configuration hash. For example:
|
222
|
+
|
223
|
+
```ruby
|
224
|
+
preprocessing_steps = {
|
225
|
+
training: {
|
226
|
+
annual_revenue: {
|
227
|
+
median: true,
|
228
|
+
clip: { min: 0, max: 1_000_000 }
|
229
|
+
},
|
230
|
+
loan_purpose: {
|
231
|
+
categorical: {
|
232
|
+
categorical_min: 2,
|
233
|
+
one_hot: true
|
234
|
+
}
|
235
|
+
}
|
236
|
+
}
|
237
|
+
}
|
238
|
+
```
|
239
|
+
|
240
|
+
2. **Create a dataset** using the `EasyML::Data::Dataset` class, providing necessary configurations such as data source, target, and preprocessing steps.
|
241
|
+
|
242
|
+
3. **Train a model** using the `EasyML::Models` module, specifying the model class and configuration.
|
243
|
+
|
244
|
+
4. **Deploy the model** by marking it as live and storing it in the configured S3 bucket.
|
245
|
+
|
246
|
+
## Development
|
247
|
+
|
248
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
249
|
+
|
250
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
251
|
+
|
252
|
+
## Contributing
|
253
|
+
|
254
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/easy_ml. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [code of conduct](https://github.com/[USERNAME]/easy_ml/blob/main/CODE_OF_CONDUCT.md).
|
255
|
+
|
256
|
+
## License
|
257
|
+
|
258
|
+
The gem is available as open source under the terms of the [MIT License](https://opensource.org/licenses/MIT).
|
259
|
+
|
260
|
+
## Code of Conduct
|
261
|
+
|
262
|
+
Everyone interacting in the EasyML project's codebases, issue trackers, chat rooms, and mailing lists is expected to follow the [code of conduct](https://github.com/[USERNAME]/easy_ml/blob/main/CODE_OF_CONDUCT.md).
|
263
|
+
|
264
|
+
## Expected Future Enhancements
|
265
|
+
|
266
|
+
- **Support for Additional Models**: Integration with LightGBM, TensorFlow, and PyTorch.
|
267
|
+
- **Expanded Data Source Support**: Ability to pull data from SQL databases and REST APIs.
|
268
|
+
- **Enhanced Deployment Options**: More flexible deployment strategies and integration with CI/CD pipelines.
|
269
|
+
- **Advanced Monitoring and Logging**: Improved tools for monitoring model performance and logging.
|
270
|
+
- **User Interface Improvements**: Enhanced UI components for managing models and datasets.
|
data/Rakefile
ADDED
@@ -0,0 +1,59 @@
|
|
1
|
+
require_relative "../../../lib/easy_ml/core/model"
|
2
|
+
module EasyML
  # ActiveRecord-backed model record stored in the easy_ml_models table.
  # Rows sharing a `name` are versions of the same model; `is_live` flags the
  # version currently in use.
  class Model < ActiveRecord::Base
    include EasyML::Core::ModelCore

    self.table_name = "easy_ml_models"

    scope :live, -> { where(is_live: true) }
    attribute :root_dir, :string
    after_initialize :apply_defaults

    validate :only_one_model_is_live?

    # Validation guarding the "at most one live version per name" invariant.
    #
    # This record is excluded from the comparison so that an already-live
    # record can still be updated and re-saved.
    #
    # Raises RuntimeError if more than one *other* version is already live,
    # since the database is then in an inconsistent state. Otherwise, adds an
    # error on :is_live when this record is live while another version is
    # live too; use #mark_live to switch versions. Skipped entirely while
    # @marking_live is set.
    def only_one_model_is_live?
      return if @marking_live

      other_live_count = previous_versions.where.not(id: id).live.count

      if other_live_count > 1
        raise "Multiple previous versions of #{name} are live! This should never happen. Update previous versions to is_live=false before proceeding"
      end

      return unless is_live && other_live_count.positive?

      errors.add(:is_live,
                 "cannot mark model live when previous version is live. Explicitly use the mark_live method to mark this as the live version")
    end

    # Makes this record the only live version of its name, in one transaction.
    #
    # update_all bypasses validations and callbacks, so the invariant check
    # above does not interfere.
    #
    # @return [Integer] number of rows promoted (the result of the final update_all)
    # @raise [RuntimeError] if the record has not been saved yet. With a nil id
    #   the first query would demote every version and the second would
    #   promote nothing.
    def mark_live
      raise "Cannot mark an unsaved model live. Save it first." if new_record?

      updated = transaction do
        self.class.where(name: name).where.not(id: id).update_all(is_live: false)
        self.class.where(id: id).update_all(is_live: true)
      end

      # update_all does not touch this in-memory object; mirror the persisted
      # value without leaving is_live flagged as an unsaved change.
      self.is_live = true
      clear_attribute_changes([:is_live])
      updated
    end

    # All versions sharing this record's name, newest id first (includes self
    # once persisted).
    def previous_versions
      EasyML::Model.where(name: name).order(id: :desc)
    end

    private

    # Paths of model files that cleanup must keep. These are:
    # - this record's file;
    # - every live version;
    # - the newest non-live versions of each live model;
    # - for names with no live version, up to five versions created within
    #   the last two days.
    #
    # Records without a stored file (nil path) are ignored.
    def files_to_keep
      live_models = self.class.live

      recent_copies = live_models.flat_map do |live|
        # Newest non-live versions of this live model. One fewer is taken when
        # it shares this record's name, presumably to leave room for self,
        # which is always kept below.
        self.class.where(name: live.name).where(is_live: false).order(created_at: :desc).limit(live.name == name ? 4 : 5)
      end

      recent_versions = self.class
                        .where.not(
                          "EXISTS (SELECT 1 FROM easy_ml_models e2 WHERE e2.name = easy_ml_models.name AND e2.is_live = true)"
                        )
                        .where("created_at >= ?", 2.days.ago)
                        .order(created_at: :desc)
                        .group_by(&:name)
                        .flat_map { |_, models| models.take(5) }

      ([self] + recent_versions + recent_copies + live_models).compact.map(&:file).map(&:path).compact.uniq
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
require "carrierwave"
|
2
|
+
require_relative "model_core"
|
3
|
+
require_relative "uploaders/model_uploader"
|
4
|
+
|
5
|
+
module EasyML
  module Core
    # Plain-Ruby (non-ActiveRecord) model built on the GlueGun DSL. It shares
    # its behaviour with the ActiveRecord-backed EasyML::Model through
    # EasyML::Core::ModelCore.
    class Model
      include GlueGun::DSL

      attribute :name, :string
      attribute :version, :string
      attribute :task, :string, default: "regression"
      attribute :metrics, :array
      attribute :ml_model, :string
      attribute :file, :string
      attribute :root_dir, :string
      attribute :objective
      attribute :evaluator
      attribute :evaluator_metric

      # Must come after the attribute declarations: ModelCore's included hook
      # mounts a CarrierWave uploader on :file, presumably taking over the
      # plain :file attribute declared above. Keep this ordering.
      include EasyML::Core::ModelCore

      # @param options [Hash] attribute values, passed through to GlueGun's
      #   initializer (assumed to assign the declared attributes; confirm)
      def initialize(options = {})
        super
        # Fill in version, metrics and ml_model when they were not provided.
        apply_defaults
      end
    end
  end
end
|
@@ -0,0 +1,181 @@
|
|
1
|
+
require "carrierwave"
|
2
|
+
require_relative "uploaders/model_uploader"
|
3
|
+
|
4
|
+
module EasyML
  module Core
    # Behaviour shared by EasyML models. It is mixed into both the plain
    # EasyML::Core::Model and the ActiveRecord-backed EasyML::Model.
    #
    # Methods this module calls on its host include:
    # - task, metrics, name, version, root_dir, evaluator and errors;
    # - the subclass hooks train, train_in_batches and _save_model_file.
    module ModelCore
      attr_accessor :dataset

      # Wires the host class up with:
      # - the GlueGun DSL;
      # - a CarrierWave uploader mounted on :file;
      # - the model validations;
      # - a before_validation hook that writes the model file once the model
      #   has been fit.
      def self.included(base)
        base.send(:include, GlueGun::DSL)
        base.send(:extend, CarrierWave::Mount)
        base.send(:mount_uploader, :file, EasyML::Core::Uploaders::ModelUploader)

        base.class_eval do
          validates :task, inclusion: { in: %w[regression classification] }
          validates :task, presence: true
          validate :dataset_is_a_dataset?
          validate :validate_any_metrics?
          validate :validate_metrics_for_task
          before_validation :save_model_file, if: -> { fit? }
        end
      end

      # Validation: when a dataset is set, it must be an EasyML::Data::Dataset
      # (or a subclass of it).
      def dataset_is_a_dataset?
        return if dataset.nil?
        return if dataset.is_a?(EasyML::Data::Dataset)

        errors.add(:dataset, "Must be a subclass of EasyML::Data::Dataset")
      end

      # Validation: at least one metric must be configured. A nil metrics list
      # counts as empty rather than raising NoMethodError.
      def validate_any_metrics?
        return if Array(metrics).any?

        errors.add(:metrics, "Must include at least one metric. Allowed metrics are #{allowed_metrics.join(", ")}")
      end

      # Validation: every configured metric must be one of #allowed_metrics
      # for the current task.
      def validate_metrics_for_task
        allowed = allowed_metrics
        nonsensical_metrics = Array(metrics).reject { |metric| allowed.include?(metric) }
        return if nonsensical_metrics.empty?

        errors.add(:metrics,
                   "cannot use metrics: #{nonsensical_metrics.join(", ")} for task #{task}. Allowed metrics are: #{allowed.join(", ")}")
      end

      # Trains the model and marks it as fit.
      #
      # Without x_train, the dataset is refreshed and the subclass's
      # train_in_batches is used. Otherwise the given splits are passed to the
      # subclass's train.
      def fit(x_train: nil, y_train: nil, x_valid: nil, y_valid: nil)
        if x_train.nil?
          dataset.refresh!
          train_in_batches
        else
          train(x_train, y_train, x_valid, y_valid)
        end
        @is_fit = true
      end

      # Delegates label decoding to the dataset.
      def decode_labels(ys, col: nil)
        dataset.decode_labels(ys, col: col)
      end

      # Scores predictions via EasyML::Core::ModelEvaluator. Falls back to the
      # model's own evaluator when none is given.
      def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
        evaluator ||= self.evaluator
        EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true,
                                              evaluator: evaluator)
      end

      # @abstract
      def predict(xs)
        raise NotImplementedError, "Subclasses must implement predict method"
      end

      # @abstract
      def load
        raise NotImplementedError, "Subclasses must implement load method"
      end

      # Persists the model.
      #
      # When the host inherits a #save (ActiveRecord hosts), this delegates to
      # it, forwarding any arguments (e.g. validate: false), and returns its
      # result, so callers keep ActiveRecord's true/false contract.
      #
      # The model file is not written again here. The before_validation hook
      # registered in .included already writes it when the model is fit. Note
      # that save(validate: false) skips that hook.
      #
      # Hosts without an inherited #save just write the model file. That
      # raises if the model has not been fit.
      def save(*args, **options, &block)
        if defined?(super) && self.class.superclass.method_defined?(:save)
          super
        else
          save_model_file
        end
      end

      # Writes the trained model to <model_dir>/<version>.json via the
      # subclass's _save_model_file, then attaches and stores it through the
      # uploader. Afterwards, stale files are rotated out (see #cleanup).
      #
      # @raise [RuntimeError] if the model has not been fit
      def save_model_file
        raise "No trained model! Need to train model before saving (call model.fit)" unless fit?

        path = File.join(model_dir, "#{version}.json")
        ensure_directory_exists(File.dirname(path))

        _save_model_file(path)

        File.open(path) do |f|
          self.file = f
        end
        file.store!

        cleanup
      end

      # Hyperparameters as a Hash.
      def get_params
        @hyperparameters.to_h
      end

      # Metric names valid for the current task; [] when task is blank or
      # unrecognised.
      def allowed_metrics
        return [] unless task.present?

        case task.to_sym
        when :regression
          %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
        when :classification
          %w[accuracy_score precision_score recall_score f1_score auc roc_auc]
        else
          []
        end
      end

      # Rotates file_dir and model_dir, keeping no files.
      def cleanup!
        [file_dir, model_dir].each do |dir|
          EasyML::FileRotate.new(dir, []).cleanup(extension_allowlist)
        end
      end

      # Rotates file_dir and model_dir, keeping #files_to_keep.
      def cleanup
        [file_dir, model_dir].each do |dir|
          EasyML::FileRotate.new(dir, files_to_keep).cleanup(extension_allowlist)
        end
      end

      # True once #fit has completed.
      def fit?
        @is_fit == true
      end

      private

      # Directory one level above the one holding the stored model file, or
      # nil when no file has been stored.
      def file_dir
        path = file.path
        return unless path.present?

        File.dirname(File.dirname(path))
      end

      def extension_allowlist
        EasyML::Core::Uploaders::ModelUploader.new.extension_allowlist
      end

      # @abstract Writes the trained model to +path+.
      def _save_model_file(path = nil)
        raise NotImplementedError, "Subclasses must implement _save_model_file method"
      end

      def ensure_directory_exists(dir)
        FileUtils.mkdir_p(dir) unless File.directory?(dir)
      end

      # Fills version, metrics and ml_model when they are not already set.
      def apply_defaults
        self.version ||= generate_version_string
        self.metrics ||= allowed_metrics
        self.ml_model ||= get_ml_model
      end

      # Underscored, unqualified class name, e.g. "xg_boost" for
      # EasyML::Models::XGBoost.
      def get_ml_model
        self.class.name.split("::").last.underscore
      end

      # "<ml_model>_<UTC timestamp to the second>".
      def generate_version_string
        timestamp = Time.now.utc.strftime("%Y%m%d%H%M%S")
        "#{get_ml_model}_#{timestamp}"
      end

      def model_dir
        File.join(root_dir, "easy_ml_models", name.present? ? name.split.join.underscore : "")
      end

      # Default keep list: the five newest files in file_dir. Files are ordered
      # by the digits embedded in their file names, parsed as a time. Hosts may
      # override this (EasyML::Model does).
      #
      # NOTE(review): Time.parse raises ArgumentError for a file name with no
      # parseable digits.
      def files_to_keep
        Dir.glob(File.join(file_dir, "*")).select { |f| File.file?(f) }.sort_by do |filename|
          Time.parse(filename.split("/").last.gsub(/\D/, ""))
        end.reverse.take(5)
      end
    end
  end
end
|