easy_ml 0.2.0.pre.rc71 → 0.2.0.pre.rc75
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/app/controllers/easy_ml/datasets_controller.rb +33 -0
- data/app/controllers/easy_ml/datasources_controller.rb +7 -0
- data/app/controllers/easy_ml/models_controller.rb +46 -0
- data/app/frontend/components/DatasetCard.tsx +212 -0
- data/app/frontend/components/ModelCard.tsx +114 -29
- data/app/frontend/components/StackTrace.tsx +13 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +10 -7
- data/app/frontend/components/datasets/UploadDatasetButton.tsx +51 -0
- data/app/frontend/components/models/DownloadModelModal.tsx +90 -0
- data/app/frontend/components/models/UploadModelModal.tsx +212 -0
- data/app/frontend/components/models/index.ts +2 -0
- data/app/frontend/pages/DatasetsPage.tsx +36 -130
- data/app/frontend/pages/DatasourcesPage.tsx +22 -2
- data/app/frontend/pages/ModelsPage.tsx +37 -11
- data/app/frontend/types/dataset.ts +1 -2
- data/app/frontend/types.ts +1 -1
- data/app/jobs/easy_ml/reaper.rb +55 -0
- data/app/jobs/easy_ml/training_job.rb +1 -1
- data/app/models/easy_ml/column/imputers/base.rb +4 -0
- data/app/models/easy_ml/column/imputers/clip.rb +5 -3
- data/app/models/easy_ml/column/imputers/imputer.rb +11 -13
- data/app/models/easy_ml/column/imputers/mean.rb +7 -3
- data/app/models/easy_ml/column/imputers/null_imputer.rb +3 -0
- data/app/models/easy_ml/column/imputers/ordinal_encoder.rb +5 -1
- data/app/models/easy_ml/column/imputers.rb +3 -1
- data/app/models/easy_ml/column/lineage/base.rb +5 -1
- data/app/models/easy_ml/column/lineage/computed_by_feature.rb +1 -1
- data/app/models/easy_ml/column/lineage/preprocessed.rb +1 -1
- data/app/models/easy_ml/column/lineage/raw_dataset.rb +1 -1
- data/app/models/easy_ml/column/selector.rb +4 -0
- data/app/models/easy_ml/column.rb +79 -63
- data/app/models/easy_ml/column_history.rb +28 -28
- data/app/models/easy_ml/column_list/imputer.rb +23 -0
- data/app/models/easy_ml/column_list.rb +39 -26
- data/app/models/easy_ml/dataset/learner/base.rb +34 -0
- data/app/models/easy_ml/dataset/learner/eager/boolean.rb +10 -0
- data/app/models/easy_ml/dataset/learner/eager/categorical.rb +51 -0
- data/app/models/easy_ml/dataset/learner/eager/query.rb +37 -0
- data/app/models/easy_ml/dataset/learner/eager.rb +43 -0
- data/app/models/easy_ml/dataset/learner/lazy/boolean.rb +13 -0
- data/app/models/easy_ml/dataset/learner/lazy/categorical.rb +10 -0
- data/app/models/easy_ml/dataset/learner/lazy/datetime.rb +19 -0
- data/app/models/easy_ml/dataset/learner/lazy/null.rb +17 -0
- data/app/models/easy_ml/dataset/learner/lazy/numeric.rb +19 -0
- data/app/models/easy_ml/dataset/learner/lazy/query.rb +69 -0
- data/app/models/easy_ml/dataset/learner/lazy/string.rb +19 -0
- data/app/models/easy_ml/dataset/learner/lazy.rb +51 -0
- data/app/models/easy_ml/dataset/learner/query.rb +25 -0
- data/app/models/easy_ml/dataset/learner.rb +100 -0
- data/app/models/easy_ml/dataset.rb +150 -36
- data/app/models/easy_ml/dataset_history.rb +1 -0
- data/app/models/easy_ml/datasource.rb +9 -0
- data/app/models/easy_ml/event.rb +5 -7
- data/app/models/easy_ml/export/column.rb +27 -0
- data/app/models/easy_ml/export/dataset.rb +37 -0
- data/app/models/easy_ml/export/datasource.rb +12 -0
- data/app/models/easy_ml/export/feature.rb +24 -0
- data/app/models/easy_ml/export/model.rb +40 -0
- data/app/models/easy_ml/export/retraining_job.rb +20 -0
- data/app/models/easy_ml/export/splitter.rb +14 -0
- data/app/models/easy_ml/feature.rb +21 -0
- data/app/models/easy_ml/import/column.rb +35 -0
- data/app/models/easy_ml/import/dataset.rb +148 -0
- data/app/models/easy_ml/import/feature.rb +36 -0
- data/app/models/easy_ml/import/model.rb +136 -0
- data/app/models/easy_ml/import/retraining_job.rb +29 -0
- data/app/models/easy_ml/import/splitter.rb +34 -0
- data/app/models/easy_ml/lineage.rb +44 -0
- data/app/models/easy_ml/model.rb +101 -37
- data/app/models/easy_ml/model_file.rb +6 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +7 -7
- data/app/models/easy_ml/models/xgboost.rb +33 -9
- data/app/models/easy_ml/retraining_job.rb +8 -1
- data/app/models/easy_ml/retraining_run.rb +7 -5
- data/app/models/easy_ml/splitter.rb +8 -0
- data/app/models/lineage_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +7 -1
- data/app/serializers/easy_ml/dataset_serializer.rb +2 -1
- data/app/serializers/easy_ml/lineage_serializer.rb +9 -0
- data/config/routes.rb +14 -1
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +3 -3
- data/lib/easy_ml/core/tuner.rb +13 -12
- data/lib/easy_ml/data/polars_column.rb +149 -100
- data/lib/easy_ml/data/polars_reader.rb +8 -5
- data/lib/easy_ml/data/polars_schema.rb +56 -0
- data/lib/easy_ml/data/splits/file_split.rb +20 -2
- data/lib/easy_ml/data/splits/split.rb +10 -1
- data/lib/easy_ml/data.rb +1 -0
- data/lib/easy_ml/deep_compact.rb +19 -0
- data/lib/easy_ml/engine.rb +1 -0
- data/lib/easy_ml/feature_store.rb +2 -6
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +6 -0
- data/lib/easy_ml/railtie/templates/migration/add_extra_metadata_to_columns.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/add_raw_schema_to_datasets.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/add_unique_constraint_to_easy_ml_model_names.rb.tt +8 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_lineages.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/remove_evaluator_from_retraining_jobs.rb.tt +7 -0
- data/lib/easy_ml/railtie/templates/migration/update_preprocessing_steps_to_jsonb.rb.tt +18 -0
- data/lib/easy_ml/timing.rb +34 -0
- data/lib/easy_ml/version.rb +1 -1
- data/lib/easy_ml.rb +2 -0
- data/public/easy_ml/assets/.vite/manifest.json +2 -2
- data/public/easy_ml/assets/assets/Application-Q7L6ioxr.css +1 -0
- data/public/easy_ml/assets/assets/entrypoints/Application.tsx-Rrzo4ecT.js +522 -0
- data/public/easy_ml/assets/assets/entrypoints/Application.tsx-Rrzo4ecT.js.map +1 -0
- metadata +53 -12
- data/app/models/easy_ml/column/learners/base.rb +0 -103
- data/app/models/easy_ml/column/learners/boolean.rb +0 -11
- data/app/models/easy_ml/column/learners/categorical.rb +0 -51
- data/app/models/easy_ml/column/learners/datetime.rb +0 -19
- data/app/models/easy_ml/column/learners/null.rb +0 -22
- data/app/models/easy_ml/column/learners/numeric.rb +0 -33
- data/app/models/easy_ml/column/learners/string.rb +0 -15
- data/public/easy_ml/assets/assets/Application-BbFobaXt.css +0 -1
- data/public/easy_ml/assets/assets/entrypoints/Application.tsx-CibZcrBc.js +0 -489
- data/public/easy_ml/assets/assets/entrypoints/Application.tsx-CibZcrBc.js.map +0 -1
@@ -32,9 +32,9 @@ module EasyML
|
|
32
32
|
false
|
33
33
|
end
|
34
34
|
|
35
|
-
def
|
35
|
+
def valid_dataset
|
36
36
|
if tuner.present?
|
37
|
-
[tuner.
|
37
|
+
[tuner.x_valid, tuner.y_valid]
|
38
38
|
else
|
39
39
|
model.dataset.valid(split_ys: true)
|
40
40
|
end
|
@@ -46,12 +46,12 @@ module EasyML
|
|
46
46
|
log_frequency = 10
|
47
47
|
if epoch % log_frequency == 0
|
48
48
|
model.adapter.external_model = booster
|
49
|
-
|
50
|
-
@preprocessed ||= model.preprocess(
|
49
|
+
x_valid, y_valid = valid_dataset
|
50
|
+
@preprocessed ||= model.preprocess(x_valid)
|
51
51
|
y_pred = model.predict(@preprocessed)
|
52
|
-
dataset = model.dataset.
|
52
|
+
dataset = model.dataset.valid(all_columns: true)
|
53
53
|
|
54
|
-
metrics = model.evaluate(y_pred: y_pred, y_true:
|
54
|
+
metrics = model.evaluate(y_pred: y_pred, y_true: y_valid, x_true: x_valid, dataset: dataset)
|
55
55
|
Wandb.log(metrics)
|
56
56
|
end
|
57
57
|
|
@@ -67,7 +67,7 @@ module EasyML
|
|
67
67
|
def after_training(booster)
|
68
68
|
return booster unless wandb_enabled?
|
69
69
|
|
70
|
-
if model.last_run&.wandb_url.nil?
|
70
|
+
if model.last_run.present? && model.last_run&.wandb_url.nil?
|
71
71
|
if tuner.present? && !tuner.current_run.wandb_url.present?
|
72
72
|
tuner.current_run.wandb_url = Wandb.current_run.url
|
73
73
|
end
|
@@ -199,7 +199,7 @@ module EasyML
|
|
199
199
|
set_default_wandb_project_name unless tuning
|
200
200
|
|
201
201
|
# Prepare validation data
|
202
|
-
x_valid, y_valid = dataset.valid(split_ys: true)
|
202
|
+
x_valid, y_valid = dataset.valid(split_ys: true, select: dataset.col_order)
|
203
203
|
d_valid = preprocess(x_valid, y_valid)
|
204
204
|
|
205
205
|
num_iterations = hyperparameters.to_h[:n_estimators]
|
@@ -217,7 +217,7 @@ module EasyML
|
|
217
217
|
callbacks << ::XGBoost::EvaluationMonitor.new(period: 1)
|
218
218
|
|
219
219
|
# Generate batches without loading full dataset
|
220
|
-
batches = dataset.train(split_ys: true, batch_size: batch_size, batch_start: batch_start, batch_key: batch_key)
|
220
|
+
batches = dataset.train(split_ys: true, batch_size: batch_size, batch_start: batch_start, batch_key: batch_key, select: dataset.col_order)
|
221
221
|
prev_xs = []
|
222
222
|
prev_ys = []
|
223
223
|
|
@@ -281,9 +281,32 @@ module EasyML
|
|
281
281
|
return @booster
|
282
282
|
end
|
283
283
|
|
284
|
-
def weights
|
285
|
-
|
286
|
-
|
284
|
+
def weights(model_file)
|
285
|
+
return nil unless model_file.present? && model_file.fit?
|
286
|
+
|
287
|
+
JSON.parse(model_file.read)
|
288
|
+
end
|
289
|
+
|
290
|
+
def set_weights(model_file, weights)
|
291
|
+
raise ArgumentError, "Weights must be provided" unless weights.present?
|
292
|
+
|
293
|
+
# Create a temp file with the weights
|
294
|
+
temp_file = Tempfile.new(["xgboost_weights", ".json"])
|
295
|
+
begin
|
296
|
+
temp_file.write(weights.to_json)
|
297
|
+
temp_file.close
|
298
|
+
|
299
|
+
# Load the weights into a new booster
|
300
|
+
initialize_model do
|
301
|
+
attrs = {
|
302
|
+
params: hyperparameters.to_h.symbolize_keys.compact,
|
303
|
+
model_file: temp_file.path,
|
304
|
+
}.compact
|
305
|
+
booster_class.new(**attrs)
|
306
|
+
end
|
307
|
+
ensure
|
308
|
+
temp_file.unlink
|
309
|
+
end
|
287
310
|
end
|
288
311
|
|
289
312
|
def predict(xs)
|
@@ -397,11 +420,12 @@ module EasyML
|
|
397
420
|
|
398
421
|
def prepare_data
|
399
422
|
if @d_train.nil?
|
400
|
-
|
423
|
+
col_order = dataset.col_order
|
424
|
+
x_sample, y_sample = dataset.train(split_ys: true, limit: 5, select: col_order)
|
401
425
|
preprocess(x_sample, y_sample) # Ensure we fail fast if the dataset is misconfigured
|
402
|
-
x_train, y_train = dataset.train(split_ys: true)
|
403
|
-
x_valid, y_valid = dataset.valid(split_ys: true)
|
404
|
-
x_test, y_test = dataset.test(split_ys: true)
|
426
|
+
x_train, y_train = dataset.train(split_ys: true, select: col_order)
|
427
|
+
x_valid, y_valid = dataset.valid(split_ys: true, select: col_order)
|
428
|
+
x_test, y_test = dataset.test(split_ys: true, select: col_order)
|
405
429
|
@d_train = preprocess(x_train, y_train)
|
406
430
|
@d_valid = preprocess(x_valid, y_valid)
|
407
431
|
@d_test = preprocess(x_test, y_test)
|
@@ -6,7 +6,6 @@
|
|
6
6
|
# model_id :bigint
|
7
7
|
# frequency :string not null
|
8
8
|
# at :json not null
|
9
|
-
# evaluator :json
|
10
9
|
# tuning_enabled :boolean default(FALSE)
|
11
10
|
# tuner_config :json
|
12
11
|
# tuning_frequency :string
|
@@ -160,6 +159,14 @@ module EasyML
|
|
160
159
|
}[frequency.to_sym]
|
161
160
|
end
|
162
161
|
|
162
|
+
def to_config
|
163
|
+
EasyML::Export::RetrainingJob.to_config(self)
|
164
|
+
end
|
165
|
+
|
166
|
+
def self.from_config(config, model)
|
167
|
+
EasyML::Import::RetrainingJob.from_config(config, model)
|
168
|
+
end
|
169
|
+
|
163
170
|
private
|
164
171
|
|
165
172
|
def metric_class
|
@@ -37,7 +37,7 @@ module EasyML
|
|
37
37
|
belongs_to :model_file, class_name: "EasyML::ModelFile", optional: true
|
38
38
|
has_many :events, as: :eventable, class_name: "EasyML::Event", dependent: :destroy
|
39
39
|
|
40
|
-
validates :status, presence: true, inclusion: { in: %w[pending running success failed deployed] }
|
40
|
+
validates :status, presence: true, inclusion: { in: %w[pending running success failed deployed aborted] }
|
41
41
|
|
42
42
|
scope :running, -> { where(status: "running") }
|
43
43
|
|
@@ -83,7 +83,6 @@ module EasyML
|
|
83
83
|
completed_at: failed_reasons.none? ? Time.current : nil,
|
84
84
|
error_message: failed_reasons.any? ? failed_reasons&.first : nil,
|
85
85
|
model: training_model,
|
86
|
-
metrics: training_model.evaluate,
|
87
86
|
best_params: best_params,
|
88
87
|
tuner_job_id: tuner&.id,
|
89
88
|
metadata: tuner&.metadata,
|
@@ -109,6 +108,7 @@ module EasyML
|
|
109
108
|
end
|
110
109
|
true
|
111
110
|
rescue => e
|
111
|
+
puts EasyML::Event.easy_ml_context(e.backtrace)
|
112
112
|
EasyML::Event.handle_error(self, e)
|
113
113
|
update!(
|
114
114
|
status: "failed",
|
@@ -150,14 +150,15 @@ module EasyML
|
|
150
150
|
|
151
151
|
training_model.dataset.refresh
|
152
152
|
evaluator = retraining_job.evaluator.symbolize_keys
|
153
|
-
|
154
|
-
y_pred = training_model.predict(
|
153
|
+
x_test, y_test = training_model.dataset.test(split_ys: true)
|
154
|
+
y_pred = training_model.predict(x_test)
|
155
155
|
|
156
156
|
metric = evaluator[:metric].to_sym
|
157
157
|
metrics = EasyML::Core::ModelEvaluator.evaluate(
|
158
158
|
model: training_model,
|
159
159
|
y_pred: y_pred,
|
160
|
-
y_true:
|
160
|
+
y_true: y_test,
|
161
|
+
x_true: x_test,
|
161
162
|
dataset: training_model.dataset.test(all_columns: true),
|
162
163
|
evaluator: evaluator,
|
163
164
|
)
|
@@ -176,6 +177,7 @@ module EasyML
|
|
176
177
|
|
177
178
|
{
|
178
179
|
metric_value: metric_value,
|
180
|
+
metrics: metrics,
|
179
181
|
threshold: threshold,
|
180
182
|
threshold_direction: threshold_direction,
|
181
183
|
deployable: deployable,
|
@@ -75,6 +75,14 @@ module EasyML
|
|
75
75
|
}
|
76
76
|
end
|
77
77
|
|
78
|
+
def to_config
|
79
|
+
EasyML::Export::Splitter.to_config(self)
|
80
|
+
end
|
81
|
+
|
82
|
+
def self.from_config(config, dataset)
|
83
|
+
EasyML::Import::Splitter.from_config(config, dataset)
|
84
|
+
end
|
85
|
+
|
78
86
|
def split(df, &block)
|
79
87
|
adapter.split(df, &block)
|
80
88
|
end
|
@@ -28,10 +28,16 @@ module EasyML
|
|
28
28
|
|
29
29
|
attributes :id, :name, :description, :dataset_id, :datatype, :polars_datatype, :preprocessing_steps,
|
30
30
|
:hidden, :drop_if_null, :sample_values, :statistics, :is_target,
|
31
|
-
:is_computed, :computed_by
|
31
|
+
:is_computed, :computed_by
|
32
32
|
|
33
33
|
attribute :required do |object|
|
34
34
|
object.required?
|
35
35
|
end
|
36
|
+
|
37
|
+
attribute :lineage do |column|
|
38
|
+
column.lineages.map do |lineage|
|
39
|
+
LineageSerializer.new(lineage).serializable_hash.dig(:data, :attributes)
|
40
|
+
end
|
41
|
+
end
|
36
42
|
end
|
37
43
|
end
|
@@ -59,7 +59,8 @@ module EasyML
|
|
59
59
|
end
|
60
60
|
|
61
61
|
attribute :columns do |dataset|
|
62
|
-
dataset.
|
62
|
+
col_order = dataset.col_order
|
63
|
+
dataset.columns.sort_by { |c| col_order.index(c.name) || Float::INFINITY }.map do |column|
|
63
64
|
ColumnSerializer.new(column).serializable_hash.dig(:data, :attributes)
|
64
65
|
end
|
65
66
|
end
|
data/config/routes.rb
CHANGED
@@ -17,10 +17,16 @@ EasyML::Engine.routes.draw do
|
|
17
17
|
resources :models, as: :easy_ml_models do
|
18
18
|
member do
|
19
19
|
post :train
|
20
|
+
post :abort
|
21
|
+
get :download
|
22
|
+
post :upload
|
20
23
|
get :retraining_runs, to: "retraining_runs#index"
|
21
24
|
end
|
25
|
+
collection do
|
26
|
+
get "new", as: "new"
|
27
|
+
post :upload
|
28
|
+
end
|
22
29
|
resources :deploys, only: [:create]
|
23
|
-
get "new", on: :collection, as: "new"
|
24
30
|
end
|
25
31
|
|
26
32
|
resources :retraining_runs, only: [:show]
|
@@ -29,6 +35,7 @@ EasyML::Engine.routes.draw do
|
|
29
35
|
resources :datasources, as: :easy_ml_datasources do
|
30
36
|
member do
|
31
37
|
post :sync
|
38
|
+
post :abort
|
32
39
|
end
|
33
40
|
end
|
34
41
|
|
@@ -36,6 +43,12 @@ EasyML::Engine.routes.draw do
|
|
36
43
|
resources :datasets, as: :easy_ml_datasets do
|
37
44
|
member do
|
38
45
|
post :refresh
|
46
|
+
post :abort
|
47
|
+
get :download
|
48
|
+
post :upload
|
49
|
+
end
|
50
|
+
collection do
|
51
|
+
post :upload
|
39
52
|
end
|
40
53
|
end
|
41
54
|
|
@@ -4,7 +4,7 @@ module EasyML
|
|
4
4
|
module Adapters
|
5
5
|
class BaseAdapter
|
6
6
|
attr_accessor :config, :project_name, :tune_started_at, :model,
|
7
|
-
:
|
7
|
+
:x_valid, :y_valid, :metadata, :model
|
8
8
|
|
9
9
|
def initialize(options = {})
|
10
10
|
@model = options[:model]
|
@@ -12,8 +12,8 @@ module EasyML
|
|
12
12
|
@project_name = options[:project_name]
|
13
13
|
@tune_started_at = options[:tune_started_at]
|
14
14
|
@model = options[:model]
|
15
|
-
@
|
16
|
-
@
|
15
|
+
@x_valid = options[:x_valid]
|
16
|
+
@y_valid = options[:y_valid]
|
17
17
|
@metadata = options[:metadata] || {}
|
18
18
|
end
|
19
19
|
|
data/lib/easy_ml/core/tuner.rb
CHANGED
@@ -6,7 +6,7 @@ module EasyML
|
|
6
6
|
class Tuner
|
7
7
|
attr_accessor :model, :dataset, :project_name, :task, :config,
|
8
8
|
:metrics, :objective, :n_trials, :direction, :evaluator,
|
9
|
-
:study, :results, :adapter, :tune_started_at, :
|
9
|
+
:study, :results, :adapter, :tune_started_at, :x_valid, :y_valid,
|
10
10
|
:project_name, :job, :current_run, :trial_enumerator, :progress_block,
|
11
11
|
:tuner_job, :dataset
|
12
12
|
|
@@ -34,7 +34,7 @@ module EasyML
|
|
34
34
|
config: config,
|
35
35
|
project_name: project_name,
|
36
36
|
tune_started_at: nil, # This will be set during tune
|
37
|
-
|
37
|
+
y_valid: nil, # This will be set during tune
|
38
38
|
)
|
39
39
|
end
|
40
40
|
end
|
@@ -70,17 +70,16 @@ module EasyML
|
|
70
70
|
@job = tuner_job
|
71
71
|
@study = Optuna::Study.new(direction: direction)
|
72
72
|
@results = []
|
73
|
-
model.evaluator = evaluator if evaluator.present?
|
74
73
|
model.task = task
|
75
74
|
|
76
|
-
model.dataset.refresh
|
77
|
-
|
78
|
-
self.
|
79
|
-
self.
|
80
|
-
self.dataset = model.dataset.
|
75
|
+
model.dataset.refresh if model.dataset.needs_refresh?
|
76
|
+
x_valid, y_valid = model.dataset.valid(split_ys: true, select: model.dataset.col_order)
|
77
|
+
self.x_valid = x_valid
|
78
|
+
self.y_valid = y_valid
|
79
|
+
self.dataset = model.dataset.valid(all_columns: true)
|
81
80
|
adapter.tune_started_at = tune_started_at
|
82
|
-
adapter.
|
83
|
-
adapter.
|
81
|
+
adapter.x_valid = x_valid
|
82
|
+
adapter.y_valid = y_valid
|
84
83
|
|
85
84
|
model.prepare_data unless model.batch_mode
|
86
85
|
model.prepare_callbacks(self)
|
@@ -99,6 +98,7 @@ module EasyML
|
|
99
98
|
@results.push(result)
|
100
99
|
@study.tell(@current_trial, result)
|
101
100
|
rescue StandardError => e
|
101
|
+
puts EasyML::Event.easy_ml_context(e.backtrace)
|
102
102
|
@tuner_run.update!(status: :failed, hyperparameters: {})
|
103
103
|
puts "Optuna failed with: #{e.message}"
|
104
104
|
raise e
|
@@ -118,6 +118,7 @@ module EasyML
|
|
118
118
|
|
119
119
|
best_run&.hyperparameters
|
120
120
|
rescue StandardError => e
|
121
|
+
puts EasyML::Event.easy_ml_context(e.backtrace)
|
121
122
|
tuner_job&.update!(status: :failed, completed_at: Time.current)
|
122
123
|
raise e
|
123
124
|
end
|
@@ -137,9 +138,9 @@ module EasyML
|
|
137
138
|
end
|
138
139
|
end
|
139
140
|
|
140
|
-
y_pred = model.predict(
|
141
|
+
y_pred = model.predict(x_valid)
|
141
142
|
model.metrics = metrics
|
142
|
-
metrics = model.evaluate(y_pred: y_pred, y_true:
|
143
|
+
metrics = model.evaluate(y_pred: y_pred, y_true: y_valid, x_true: x_valid, dataset: dataset)
|
143
144
|
metric = metrics.symbolize_keys.dig(model.evaluator[:metric].to_sym)
|
144
145
|
|
145
146
|
puts metrics
|
@@ -2,7 +2,7 @@ require_relative "date_converter"
|
|
2
2
|
|
3
3
|
module EasyML
|
4
4
|
module Data
|
5
|
-
|
5
|
+
class PolarsColumn
|
6
6
|
TYPE_MAP = {
|
7
7
|
float: Polars::Float64,
|
8
8
|
integer: Polars::Int64,
|
@@ -14,132 +14,181 @@ module EasyML
|
|
14
14
|
categorical: Polars::Categorical,
|
15
15
|
null: Polars::Null,
|
16
16
|
}
|
17
|
-
POLARS_MAP =
|
17
|
+
POLARS_MAP = {
|
18
|
+
Polars::Float64 => :float,
|
19
|
+
Polars::Int64 => :integer,
|
20
|
+
Polars::Float32 => :float,
|
21
|
+
Polars::Int32 => :integer,
|
22
|
+
Polars::Boolean => :boolean,
|
23
|
+
Polars::Datetime => :datetime,
|
24
|
+
Polars::Date => :date,
|
25
|
+
Polars::String => :string,
|
26
|
+
Polars::Categorical => :categorical,
|
27
|
+
Polars::Null => :null,
|
28
|
+
}.stringify_keys
|
29
|
+
include EasyML::Timing
|
30
|
+
|
18
31
|
class << self
|
19
32
|
def polars_to_sym(polars_type)
|
20
|
-
|
33
|
+
new.polars_to_sym(polars_type)
|
34
|
+
end
|
35
|
+
|
36
|
+
def determine_type(series, polars_type = false)
|
37
|
+
new.determine_type(series, polars_type)
|
21
38
|
end
|
22
39
|
|
23
40
|
def parse_polars_dtype(dtype_string)
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
raise ArgumentError, "Unknown Polars data type: #{dtype_string}"
|
34
|
-
end
|
41
|
+
new.parse_polars_dtype(dtype_string)
|
42
|
+
end
|
43
|
+
|
44
|
+
def get_polars_type(dtype)
|
45
|
+
new.get_polars_type(dtype)
|
46
|
+
end
|
47
|
+
|
48
|
+
def polars_dtype_to_sym(dtype_string)
|
49
|
+
new.polars_dtype_to_sym(dtype_string)
|
35
50
|
end
|
36
51
|
|
37
52
|
def sym_to_polars(symbol)
|
38
|
-
|
53
|
+
new.sym_to_polars(symbol)
|
39
54
|
end
|
55
|
+
end
|
40
56
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
if
|
48
|
-
|
49
|
-
|
50
|
-
date = EasyML::Data::DateConverter.maybe_convert_date(series)
|
51
|
-
return polars_type ? date[date.columns.first].dtype : :datetime
|
52
|
-
end
|
53
|
-
end
|
57
|
+
def polars_to_sym(polars_type)
|
58
|
+
return nil if polars_type.nil?
|
59
|
+
|
60
|
+
if polars_type.is_a?(Polars::DataType)
|
61
|
+
POLARS_MAP.dig(polars_type.class.to_s)
|
62
|
+
else
|
63
|
+
polars_type.to_sym if TYPE_MAP.keys.include?(polars_type.to_sym)
|
64
|
+
end
|
65
|
+
end
|
54
66
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
when Polars::Utf8
|
67
|
-
determine_string_type(series)
|
68
|
-
when Polars::Null
|
69
|
-
:null
|
70
|
-
else
|
71
|
-
:categorical
|
72
|
-
end
|
73
|
-
|
74
|
-
polars_type ? sym_to_polars(type_name) : type_name
|
67
|
+
def parse_polars_dtype(dtype_string)
|
68
|
+
case dtype_string
|
69
|
+
when /^Polars::Datetime/
|
70
|
+
time_unit = dtype_string[/time_unit: "(.*?)"/, 1]
|
71
|
+
time_zone = dtype_string[/time_zone: (.*)?\)/, 1]
|
72
|
+
time_zone = time_zone == "nil" ? nil : time_zone&.delete('"')
|
73
|
+
Polars::Datetime.new(time_unit, time_zone)
|
74
|
+
when /^Polars::/
|
75
|
+
Polars.const_get(dtype_string.split("::").last)
|
76
|
+
else
|
77
|
+
raise ArgumentError, "Unknown Polars data type: #{dtype_string}"
|
75
78
|
end
|
79
|
+
end
|
76
80
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
81
|
+
def sym_to_polars(symbol)
|
82
|
+
TYPE_MAP.dig(symbol.to_sym)
|
83
|
+
end
|
84
|
+
|
85
|
+
# Determines the semantic type of a field based on its data
|
86
|
+
# @param series [Polars::Series] The series to analyze
|
87
|
+
# @return [Symbol] One of :numeric, :datetime, :categorical, or :text
|
88
|
+
def determine_type(series, polars_type = false)
|
89
|
+
dtype = series.dtype
|
90
|
+
|
91
|
+
if dtype.is_a?(Polars::Utf8)
|
92
|
+
string_type = determine_string_type(series)
|
93
|
+
if string_type == :datetime
|
94
|
+
date = EasyML::Data::DateConverter.maybe_convert_date(series)
|
95
|
+
return polars_type ? date[date.columns.first].dtype : :datetime
|
96
|
+
end
|
97
|
+
end
|
98
|
+
|
99
|
+
type_name = case dtype
|
100
|
+
when Polars::Float64
|
101
|
+
:float
|
102
|
+
when Polars::Int64
|
103
|
+
:integer
|
104
|
+
when Polars::Datetime
|
83
105
|
:datetime
|
106
|
+
when Polars::Date
|
107
|
+
:date
|
108
|
+
when Polars::Boolean
|
109
|
+
:boolean
|
110
|
+
when Polars::Utf8
|
111
|
+
determine_string_type(series)
|
112
|
+
when Polars::Null
|
113
|
+
:null
|
84
114
|
else
|
85
|
-
|
115
|
+
:categorical
|
86
116
|
end
|
117
|
+
|
118
|
+
polars_type ? sym_to_polars(type_name) : type_name
|
119
|
+
end
|
120
|
+
|
121
|
+
measure_method_timing :determine_type
|
122
|
+
|
123
|
+
# Determines if a string field is a date, text, or categorical
|
124
|
+
# @param series [Polars::Series] The string series to analyze
|
125
|
+
# @return [Symbol] One of :datetime, :text, or :categorical
|
126
|
+
def determine_string_type(series)
|
127
|
+
if EasyML::Data::DateConverter.maybe_convert_date(Polars::DataFrame.new({ temp: series }),
|
128
|
+
:temp)[:temp].dtype.is_a?(Polars::Datetime)
|
129
|
+
:datetime
|
130
|
+
else
|
131
|
+
categorical_or_text?(series)
|
87
132
|
end
|
133
|
+
end
|
88
134
|
|
89
|
-
|
90
|
-
# @param series [Polars::Series] The string series to analyze
|
91
|
-
# @return [Symbol] Either :categorical or :text
|
92
|
-
def categorical_or_text?(series)
|
93
|
-
return :categorical if series.null_count == series.len
|
135
|
+
measure_method_timing :determine_string_type
|
94
136
|
|
95
|
-
|
96
|
-
|
97
|
-
|
137
|
+
# Determines if a string field is categorical or free text
|
138
|
+
# @param series [Polars::Series] The string series to analyze
|
139
|
+
# @return [Symbol] Either :categorical or :text
|
140
|
+
def categorical_or_text?(series)
|
141
|
+
return :categorical if series.null_count == series.len
|
98
142
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
(value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
|
103
|
-
)
|
143
|
+
# Get non-null count for percentage calculations
|
144
|
+
non_null_count = series.len - series.null_count
|
145
|
+
return :categorical if non_null_count == 0
|
104
146
|
|
105
|
-
|
106
|
-
|
107
|
-
|
147
|
+
# Get value counts as percentages
|
148
|
+
value_counts = series.value_counts(parallel: true)
|
149
|
+
percentages = value_counts.with_column(
|
150
|
+
(value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
|
151
|
+
)
|
108
152
|
|
109
|
-
|
110
|
-
|
153
|
+
# Check if any category represents more than 10% of the data
|
154
|
+
max_percentage = percentages["percentage"].max
|
155
|
+
return :text if max_percentage < 10.0
|
111
156
|
|
112
|
-
|
113
|
-
|
114
|
-
end
|
157
|
+
# Calculate average percentage per category
|
158
|
+
avg_percentage = 100.0 / series.n_unique
|
115
159
|
|
116
|
-
#
|
117
|
-
|
118
|
-
|
119
|
-
def numeric?(field_type)
|
120
|
-
field_type == :numeric
|
121
|
-
end
|
160
|
+
# If average category represents less than 1% of data, it's likely text
|
161
|
+
avg_percentage < 1.0 ? :text : :categorical
|
162
|
+
end
|
122
163
|
|
123
|
-
|
124
|
-
# @param field_type [Symbol] The field type to check
|
125
|
-
# @return [Boolean]
|
126
|
-
def categorical?(field_type)
|
127
|
-
field_type == :categorical
|
128
|
-
end
|
164
|
+
measure_method_timing :categorical_or_text?
|
129
165
|
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
166
|
+
# Returns whether the field type is numeric
|
167
|
+
# @param field_type [Symbol] The field type to check
|
168
|
+
# @return [Boolean]
|
169
|
+
def numeric?(field_type)
|
170
|
+
field_type == :numeric
|
171
|
+
end
|
136
172
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
173
|
+
# Returns whether the field type is categorical
|
174
|
+
# @param field_type [Symbol] The field type to check
|
175
|
+
# @return [Boolean]
|
176
|
+
def categorical?(field_type)
|
177
|
+
field_type == :categorical
|
178
|
+
end
|
179
|
+
|
180
|
+
# Returns whether the field type is datetime
|
181
|
+
# @param field_type [Symbol] The field type to check
|
182
|
+
# @return [Boolean]
|
183
|
+
def datetime?(field_type)
|
184
|
+
field_type == :datetime
|
185
|
+
end
|
186
|
+
|
187
|
+
# Returns whether the field type is text
|
188
|
+
# @param field_type [Symbol] The field type to check
|
189
|
+
# @return [Boolean]
|
190
|
+
def text?(field_type)
|
191
|
+
field_type == :text
|
143
192
|
end
|
144
193
|
end
|
145
194
|
end
|