easy_ml 0.1.4 → 0.2.0.pre.rc1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -5,44 +5,23 @@ module EasyML
|
|
5
5
|
class Tuner
|
6
6
|
module Adapters
|
7
7
|
class XGBoostAdapter < BaseAdapter
  # Default hyperparameter search ranges for XGBoost models.
  #
  # Each entry maps a hyperparameter to a { min:, max: } range; `log: true` on
  # learning_rate presumably asks the sampler for a log-scale search — TODO confirm
  # against BaseAdapter.
  #
  # @return [Hash{Symbol => Hash}] a fresh hash on every call
  def defaults
    learning_rate_range = { min: 0.001, max: 0.1, log: true }
    estimator_range = { min: 100, max: 1_000 }
    depth_range = { min: 2, max: 20 }

    {
      learning_rate: learning_rate_range,
      n_estimators: estimator_range,
      max_depth: depth_range,
    }
  end
end
|
47
26
|
end
|
48
27
|
end
|
data/lib/easy_ml/core/tuner.rb
CHANGED
@@ -4,34 +4,36 @@ require_relative "tuner/adapters"
|
|
4
4
|
module EasyML
|
5
5
|
module Core
|
6
6
|
class Tuner
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
attr_accessor :model, :dataset, :project_name, :task, :config,
              :metrics, :objective, :n_trials, :direction, :evaluator,
              :study, :results, :adapter, :tune_started_at, :x_true, :y_true,
              :job, :current_run

# Builds a tuner from an options hash.
#
# @param options [Hash] recognised keys: :model, :dataset, :task, :config (default {}),
#   :metrics, :objective, :n_trials (default 100), :evaluator
# The optimisation direction is taken from the evaluator registered for :objective,
# and the project name is "<model name>_<UTC start timestamp>".
def initialize(options = {})
  @model = options[:model]
  @dataset = options[:dataset]
  @task = options[:task]
  @config = options[:config] || {}
  @metrics = options[:metrics]
  @objective = options[:objective]
  @n_trials = options[:n_trials] || 100
  @direction = EasyML::Core::ModelEvaluator.get(objective).new.direction
  @evaluator = options[:evaluator]
  @tune_started_at = EasyML::Support::UTC.now
  # NOTE(review): options[:project_name] is not used; the name is always derived from
  # the model name plus the start timestamp. Confirm whether callers expect it honoured.
  @project_name = "#{@model.name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
end
|
29
26
|
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
27
|
+
# Builds the tuning adapter for the model's type.
#
# Only "xgboost" models are supported; any other (or missing) model type yields nil.
# tune_started_at / y_true start as nil and are assigned later inside #tune.
#
# @return [Adapters::XGBoostAdapter, nil]
def initialize_adapter
  return nil unless model&.model_type == "xgboost"

  Adapters::XGBoostAdapter.new(
    model: model,
    config: config,
    project_name: project_name,
    tune_started_at: nil,
    y_true: nil,
  )
end
|
37
39
|
|
@@ -41,53 +43,112 @@ module EasyML
|
|
41
43
|
raise "Trial failed: Stopping optimization."
|
42
44
|
end
|
43
45
|
|
44
|
-
# Whether Weights & Biases reporting is enabled, i.e. an API key is configured.
#
# @return [Boolean]
def wandb_enabled?
  api_key = EasyML::Configuration.wandb_api_key
  api_key.present?
end
|
46
49
|
|
47
|
-
|
50
|
+
# Runs the hyperparameter search and returns the best trial's hyperparameters.
#
# Creates an EasyML::TunerJob record. For each of +n_trials+ Optuna trials it records
# an EasyML::TunerRun, fits and evaluates the model via #tune_once, and reports the
# score back to the study. A trial error marks that run :failed, then the job :failed,
# and is re-raised.
#
# @yield progress callback forwarded to model.fit
# @return [Hash, nil] hyperparameters of tuner_job.best_run, or nil when no run succeeded
def tune(&progress_block)
  set_defaults!
  @adapter = initialize_adapter

  tuner_params = {
    model: model,
    config: {
      n_trials: n_trials,
      objective: objective,
      hyperparameter_ranges: config,
    },
    direction: direction,
    status: :running,
    started_at: Time.current,
    # NOTE(review): the W&B entity "fundera" is hard-coded — confirm it should not be configurable.
    wandb_url: wandb_enabled? ? "https://wandb.ai/fundera/#{@project_name}" : nil,
  }.compact

  tuner_job = EasyML::TunerJob.create!(tuner_params)
  @job = tuner_job
  @study = Optuna::Study.new(direction: direction)
  @results = []
  model.evaluator = evaluator if evaluator.present?
  model.task = task

  model.dataset.refresh
  x_true, y_true = model.dataset.test(split_ys: true)
  # Expose the held-out split on the tuner (passed to callbacks below) and the adapter.
  self.x_true = x_true
  self.y_true = y_true
  adapter.tune_started_at = tune_started_at
  adapter.y_true = y_true
  adapter.x_true = x_true

  model.prepare_data unless model.batch_mode
  model.prepare_callbacks(self)

  n_trials.times do
    trial = @study.ask
    puts "Running trial #{trial.number}"
    @tuner_run = tuner_job.tuner_runs.new(
      trial_number: trial.number,
      status: :running,
    )

    self.current_run = @tuner_run

    begin
      run_metrics = tune_once(trial, x_true, y_true, adapter, &progress_block)
      result = calculate_result(run_metrics)
      @results.push(result)

      params = {
        hyperparameters: model.hyperparameters.to_h,
        value: result,
        status: :success,
      }.compact

      @tuner_run.update!(params)
      @study.tell(trial, result)
    rescue StandardError => e
      @tuner_run.update!(status: :failed, hyperparameters: {})
      puts "Optuna failed with: #{e.message}"
      raise
    end
  end

  model.after_tuning
  if tuner_job.tuner_runs.all?(&:failed?)
    # No successful run (e.g. n_trials <= 0): close the job out instead of leaving it :running.
    tuner_job.update!(status: :failed, completed_at: Time.current)
    return nil
  end

  best_run = tuner_job.best_run
  tuner_job.update!(
    metadata: adapter.metadata,
    best_tuner_run_id: best_run&.id,
    status: :success,
    completed_at: Time.current,
  )

  best_run&.hyperparameters
rescue StandardError
  # tuner_job is nil when the failure happened before the job record was created.
  tuner_job&.update!(status: :failed, completed_at: Time.current)
  raise
end
|
78
131
|
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
132
|
+
private

# Picks the score Optuna should optimise from a trial's metrics.
#
# Uses the metric named by the model's evaluator when one is configured, otherwise
# the tuner's objective. Keys of +run_metrics+ are symbolized in place.
#
# @param run_metrics [Hash] metric name => value for one trial
# @return [Numeric, nil] the selected metric value
def calculate_result(run_metrics)
  run_metrics.symbolize_keys!
  metric_name = model.evaluator.present? ? model.evaluator[:metric] : objective
  run_metrics[metric_name.to_sym]
end
|
85
143
|
|
86
|
-
# Fits and evaluates the model for a single Optuna trial.
#
# @param trial the trial handed out by the study
# @param x_true, y_true the held-out test features / labels
# @param adapter the adapter whose run_trial yields the model configured for +trial+
# @yield progress callback forwarded to fit
# @return [Hash] the metrics returned by the model's evaluate
def tune_once(trial, x_true, y_true, adapter, &progress_block)
  adapter.run_trial(trial) do |trial_model|
    trial_model.fit(tuning: true, &progress_block)
    predictions = trial_model.predict(x_true)
    # `metrics` here is the tuner's reader: the list of metrics to compute.
    trial_model.metrics = metrics
    evaluation = trial_model.evaluate(y_pred: predictions, y_true: y_true, x_true: x_true)
    puts evaluation
    evaluation
  end
end
|
93
154
|
|
@@ -98,7 +159,7 @@ module EasyML
|
|
98
159
|
end
|
99
160
|
raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless objective.present?
|
100
161
|
|
101
|
-
self.metrics = EasyML::
|
162
|
+
self.metrics = EasyML::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
|
102
163
|
end
|
103
164
|
end
|
104
165
|
end
|
data/lib/easy_ml/core.rb
CHANGED
@@ -0,0 +1,24 @@
|
|
1
|
+
module EasyML
  module CoreExt
    # Recursive counterpart of Hash#compact, mixed into ::Hash below.
    module Hash
      # Returns a new hash with nil values removed at every nesting level.
      #
      # Nested hashes are compacted recursively. Arrays have nested hashes compacted
      # and nil elements dropped. Entries that end up blank (per ActiveSupport's
      # #blank?, e.g. an empty nested hash/array or empty string) are omitted.
      # `false` is kept: it is a real value even though #blank? reports it as blank.
      #
      # @return [::Hash]
      def deep_compact
        each_with_object({}) do |(key, value), result|
          next if value.nil?

          # Use ::Hash explicitly — a bare `Hash` here would resolve to this module.
          compacted = case value
                      when ::Hash then value.deep_compact
                      when ::Array then value.map { |v| v.is_a?(::Hash) ? v.deep_compact : v }.compact
                      else value
                      end

          next if compacted.blank? && compacted != false

          result[key] = compacted
        end
      end
    end
  end
end

# Extend Hash class with our custom method
Hash.include EasyML::CoreExt::Hash
|
@@ -1,9 +1,15 @@
|
|
1
1
|
require "pathname"

module EasyML
  module CoreExt
    # Path helpers mixed into ::Pathname below.
    module Pathname
      # Returns this path, cleaned, with +folder+ appended — unless the cleaned path
      # already ends with +folder+, in which case it is returned unchanged. This makes
      # repeated calls idempotent (dir.append("x").append("x") == dir.append("x")).
      #
      # @param folder [String, Symbol, ::Pathname] one or more relative path segments
      #   (e.g. "splits" or "splits/train"); an absolute path replaces the base as with #join
      # @return [::Pathname]
      def append(folder)
        dir = cleanpath
        suffix = ::Pathname.new(folder.to_s)
        return dir.join(suffix) if suffix.absolute?

        parts = suffix.cleanpath.each_filename.to_a
        already_there = !parts.empty? && dir.each_filename.to_a.last(parts.length) == parts
        already_there ? dir : dir.join(suffix)
      end
    end
  end
end

Pathname.include EasyML::CoreExt::Pathname
|
@@ -0,0 +1,90 @@
|
|
1
|
+
module EasyML
  module Data
    # Detects string columns whose values look like dates and converts them to Polars Datetime.
    module DateConverter
      # Candidate Ruby strptime formats, tried in order; the first one that parses
      # every sampled value wins, so more specific formats come first.
      COMMON_DATE_FORMATS = [
        "%Y-%m-%dT%H:%M:%S.%6N", # e.g., "2021-01-01T00:00:00.000000"
        "%Y-%m-%d %H:%M:%S.%L",  # e.g., "2021-01-01 00:01:36.000"
        "%Y-%m-%d %H:%M:%S",     # e.g., "2021-01-01 00:01:36"
        "%Y-%m-%d %H:%M",        # e.g., "2021-01-01 00:01"
        "%Y-%m-%d",              # e.g., "2021-01-01"
        "%m/%d/%Y %H:%M:%S",     # e.g., "01/01/2021 00:01:36"
        "%m/%d/%Y",              # e.g., "01/01/2021"
        "%d-%m-%Y",              # e.g., "01-01-2021"
        "%d-%b-%Y %H:%M:%S",     # e.g., "01-Jan-2021 00:01:36"
        "%d-%b-%Y",              # e.g., "01-Jan-2021"
        "%b %d, %Y",             # e.g., "Jan 01, 2021"
        "%Y/%m/%d %H:%M:%S",     # e.g., "2021/01/01 00:01:36"
        "%Y/%m/%d",              # e.g., "2021/01/01"
      ].freeze

      # Ruby strptime directives whose Polars spelling differs.
      FORMAT_MAPPINGS = {
        ruby_to_polars: {
          "%L" => "%3f",  # milliseconds
          "%6N" => "%6f", # microseconds
          "%N" => "%9f",  # nanoseconds
        },
      }.freeze

      # Maximum number of non-null values inspected when guessing a format.
      SAMPLE_SIZE = 100

      class << self
        # Attempts to convert a string column to datetime if it appears to be a date.
        #
        # @param df [Polars::DataFrame, Polars::Series] the dataframe containing the column,
        #   or — when +column+ is nil — a single series, which is wrapped in a new DataFrame
        # @param column [String, Symbol, nil] the name of the column to convert
        # @return [Polars::DataFrame] the dataframe with the column converted when a format
        #   is detected; otherwise the (possibly wrapped) input unchanged
        def maybe_convert_date(df, column = nil)
          if column.nil?
            series = df
            column = series.name
            df = Polars::DataFrame.new(series)
          else
            series = df[column]
          end
          return df if series.dtype.is_a?(Polars::Datetime)
          return df unless series.dtype == Polars::Utf8

          format = detect_polars_format(series)
          return df unless format

          df.with_column(
            Polars.col(column.to_s).str.strptime(Polars::Datetime, format).alias(column.to_s)
          )
        end

        private

        # Returns the Polars strptime format matching the series' leading non-null values, or nil.
        def detect_polars_format(series)
          return nil unless series.is_a?(Polars::Series)

          sample = series.filter(series.is_not_null).head(SAMPLE_SIZE).to_a
          ruby_format = detect_date_format(sample)
          convert_format(:ruby_to_polars, ruby_format)
        end

        # Returns the first COMMON_DATE_FORMATS entry that parses every non-nil value of the
        # (deterministic) sample, or nil when there is nothing to inspect or nothing matches.
        def detect_date_format(date_strings)
          # Compact before the emptiness check: an all-nil input must not "match" the first
          # format just because all? is vacuously true on an empty sample.
          sample = Array(date_strings).compact.first(SAMPLE_SIZE)
          return nil if sample.empty?

          COMMON_DATE_FORMATS.find do |format|
            sample.all? { |date_str| parses_as?(date_str, format) }
          end
        end

        # True when +date_str+ parses completely with +format+.
        def parses_as?(date_str, format)
          DateTime.strptime(date_str, format)
          true
        rescue StandardError
          false
        end

        # Rewrites +format+ using FORMAT_MAPPINGS[conversion]; nil stays nil.
        def convert_format(conversion, format)
          return nil if format.nil?

          FORMAT_MAPPINGS[conversion].reduce(format.dup) do |result, (from, to)|
            result.gsub(from, to)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module EasyML
  module Data
    # Helpers for recognising primary-key filters by inspecting an expression's string form.
    # The including class must implement #to_s.
    module FilterExtensions
      # Substrings whose presence marks the expression as a comparison.
      PRIMARY_KEY_OPS = [">", "<", ">=", "<=", "=", "eq", "gt", "lt", "ge", "le"].freeze

      # Whether this expression filters on the (first) primary-key column.
      #
      # Filter expressions in Polars are represented as strings like:
      #   [([(col("LOAN_APP_ID")) > (dyn int: 4)]) & ([(col("LOAN_APP_ID")) < (dyn int: 16)])]
      # The column must appear as an exact col("<name>") reference, so a key such as "ID"
      # does not match col("LOAN_APP_ID").
      #
      # @param primary_key [String, Symbol, Array, nil] key column name(s); only the first is used
      # @return [Boolean]
      def is_primary_key_filter?(primary_key)
        key = Array(primary_key).first
        return false unless key

        expr_str = to_s
        return false unless expr_str.include?(%(col("#{key}")))

        PRIMARY_KEY_OPS.any? { |op| expr_str.include?(op) }
      end

      # Extracts the numeric literals (int or float) appearing in the expression.
      #
      # @return [Array<Float>] unique values in order of first appearance
      def extract_primary_key_values
        to_s.scan(/(?:dyn int|float): (-?\d+(?:\.\d+)?)/).flatten.map(&:to_f).uniq
      end
    end
  end
end
|
27
|
+
|
28
|
+
# Mix the filter helpers into Polars expressions so any Polars::Expr can be asked
# whether it filters on a primary key.
Polars::Expr.include(EasyML::Data::FilterExtensions)
|
@@ -0,0 +1,126 @@
|
|
1
|
+
require_relative "date_converter"
|
2
|
+
|
3
|
+
module EasyML
  module Data
    # Maps between EasyML's symbolic column types and Polars dtypes, and infers a
    # symbolic type from a Polars::Series.
    module PolarsColumn
      TYPE_MAP = {
        float: Polars::Float64,
        integer: Polars::Int64,
        boolean: Polars::Boolean,
        datetime: Polars::Datetime,
        string: Polars::String,
        text: Polars::String,
        categorical: Polars::Categorical,
      }.freeze
      # Reverse lookup keyed by dtype class name. Polars::String appears twice in TYPE_MAP;
      # Hash#invert keeps the later entry, so "Polars::String" maps to :text.
      POLARS_MAP = TYPE_MAP.invert.stringify_keys.freeze
      # Symbolic types reported as numeric by .numeric? (determine_type yields :float/:integer).
      NUMERIC_TYPES = %i[numeric float integer].freeze

      class << self
        # @param polars_type [Polars::DataType, Class] a dtype instance or dtype class
        # @return [Symbol, nil] the symbolic type, or nil when unmapped
        def polars_to_sym(polars_type)
          type_class = polars_type.is_a?(Class) ? polars_type : polars_type.class
          POLARS_MAP[type_class.to_s]
        end

        # @param symbol [Symbol] a key of TYPE_MAP
        # @return [Class, nil] the Polars dtype class, or nil when unmapped
        def sym_to_polars(symbol)
          TYPE_MAP[symbol]
        end

        # Determines the semantic type of a field based on its data.
        #
        # @param series [Polars::Series] The series to analyze
        # @param polars_type [Boolean] return a Polars dtype instead of a symbol
        # @return [Symbol, Object] one of :float, :integer, :datetime, :boolean, :categorical
        #   or :text; with +polars_type+, the matching Polars dtype (for string columns
        #   detected as dates, the dtype of the converted column)
        def determine_type(series, polars_type = false)
          dtype = series.dtype

          if dtype.is_a?(Polars::Utf8)
            # Computed once: it runs date detection and value counts over the series.
            type_name = determine_string_type(series)
            if type_name == :datetime
              return :datetime unless polars_type

              converted = EasyML::Data::DateConverter.maybe_convert_date(series)
              return converted[converted.columns.first].dtype
            end
          else
            type_name = case dtype
                        when Polars::Float64 then :float
                        when Polars::Int64 then :integer
                        when Polars::Datetime then :datetime
                        when Polars::Boolean then :boolean
                        else :categorical
                        end
          end

          polars_type ? sym_to_polars(type_name) : type_name
        end

        # Determines if a string field is a date, text, or categorical.
        #
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] One of :datetime, :text, or :categorical
        def determine_string_type(series)
          converted = EasyML::Data::DateConverter.maybe_convert_date(Polars::DataFrame.new({ temp: series }), :temp)
          if converted[:temp].dtype.is_a?(Polars::Datetime)
            :datetime
          else
            categorical_or_text?(series)
          end
        end

        # Determines if a string field is categorical or free text.
        #
        # A series is :text when its most frequent value covers under 10% of non-null rows,
        # or when the average share per distinct value is under 1%; otherwise :categorical.
        #
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] Either :categorical or :text
        def categorical_or_text?(series)
          non_null_count = series.len - series.null_count
          return :categorical if non_null_count.zero?

          value_counts = series.value_counts(parallel: true)
          percentages = value_counts.with_column(
            (value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
          )

          max_percentage = percentages["percentage"].max
          return :text if max_percentage < 10.0

          avg_percentage = 100.0 / series.n_unique
          avg_percentage < 1.0 ? :text : :categorical
        end

        # Returns whether the field type is numeric.
        # Accepts the :float / :integer symbols produced by .determine_type as well as :numeric.
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def numeric?(field_type)
          NUMERIC_TYPES.include?(field_type)
        end

        # Returns whether the field type is categorical
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def categorical?(field_type)
          field_type == :categorical
        end

        # Returns whether the field type is datetime
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def datetime?(field_type)
          field_type == :datetime
        end

        # Returns whether the field type is text
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def text?(field_type)
          field_type == :text
        end
      end
    end
  end
end
|