easy_ml 0.1.4 → 0.2.0.pre.rc1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -5,44 +5,23 @@ module EasyML
|
|
5
5
|
class Tuner
|
6
6
|
module Adapters
|
7
7
|
class XGBoostAdapter < BaseAdapter
|
8
|
-
include GlueGun::DSL
|
9
|
-
|
10
8
|
def defaults
|
11
9
|
{
|
12
10
|
learning_rate: {
|
13
11
|
min: 0.001,
|
14
12
|
max: 0.1,
|
15
|
-
log: true
|
13
|
+
log: true,
|
16
14
|
},
|
17
15
|
n_estimators: {
|
18
16
|
min: 100,
|
19
|
-
max: 1_000
|
17
|
+
max: 1_000,
|
20
18
|
},
|
21
19
|
max_depth: {
|
22
20
|
min: 2,
|
23
|
-
max: 20
|
24
|
-
}
|
21
|
+
max: 20,
|
22
|
+
},
|
25
23
|
}
|
26
24
|
end
|
27
|
-
|
28
|
-
def configure_callbacks
|
29
|
-
model.customize_callbacks do |callbacks|
|
30
|
-
return unless callbacks.present?
|
31
|
-
|
32
|
-
wandb_callback = callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
|
33
|
-
return unless wandb_callback.present?
|
34
|
-
|
35
|
-
wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
|
36
|
-
wandb_callback.custom_loggers = [
|
37
|
-
lambda do |booster, _epoch, _hist|
|
38
|
-
dtrain = model.send(:preprocess, x_true, y_true)
|
39
|
-
y_pred = booster.predict(dtrain)
|
40
|
-
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
41
|
-
Wandb.log(metrics)
|
42
|
-
end
|
43
|
-
]
|
44
|
-
end
|
45
|
-
end
|
46
25
|
end
|
47
26
|
end
|
48
27
|
end
|
data/lib/easy_ml/core/tuner.rb
CHANGED
@@ -4,34 +4,36 @@ require_relative "tuner/adapters"
|
|
4
4
|
module EasyML
|
5
5
|
module Core
|
6
6
|
class Tuner
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
opt.bind_attribute :tune_started_at
|
27
|
-
opt.bind_attribute :y_true
|
28
|
-
end
|
7
|
+
attr_accessor :model, :dataset, :project_name, :task, :config,
|
8
|
+
:metrics, :objective, :n_trials, :direction, :evaluator,
|
9
|
+
:study, :results, :adapter, :tune_started_at, :x_true, :y_true,
|
10
|
+
:project_name, :job, :current_run
|
11
|
+
|
12
|
+
def initialize(options = {})
|
13
|
+
@model = options[:model]
|
14
|
+
@dataset = options[:dataset]
|
15
|
+
@project_name = options[:project_name]
|
16
|
+
@task = options[:task]
|
17
|
+
@config = options[:config] || {}
|
18
|
+
@metrics = options[:metrics]
|
19
|
+
@objective = options[:objective]
|
20
|
+
@n_trials = options[:n_trials] || 100
|
21
|
+
@direction = EasyML::Core::ModelEvaluator.get(objective).new.direction
|
22
|
+
@evaluator = options[:evaluator]
|
23
|
+
@tune_started_at = EasyML::Support::UTC.now
|
24
|
+
@project_name = "#{@model.name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
|
25
|
+
end
|
29
26
|
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
27
|
+
def initialize_adapter
|
28
|
+
case model&.model_type
|
29
|
+
when "xgboost"
|
30
|
+
Adapters::XGBoostAdapter.new(
|
31
|
+
model: model,
|
32
|
+
config: config,
|
33
|
+
project_name: project_name,
|
34
|
+
tune_started_at: nil, # This will be set during tune
|
35
|
+
y_true: nil, # This will be set during tune
|
36
|
+
)
|
35
37
|
end
|
36
38
|
end
|
37
39
|
|
@@ -41,53 +43,112 @@ module EasyML
|
|
41
43
|
raise "Trial failed: Stopping optimization."
|
42
44
|
end
|
43
45
|
|
44
|
-
def
|
45
|
-
|
46
|
+
def wandb_enabled?
|
47
|
+
EasyML::Configuration.wandb_api_key.present?
|
48
|
+
end
|
46
49
|
|
47
|
-
|
50
|
+
def tune(&progress_block)
|
51
|
+
set_defaults!
|
52
|
+
@adapter = initialize_adapter
|
53
|
+
|
54
|
+
tuner_params = {
|
55
|
+
model: model,
|
56
|
+
config: {
|
57
|
+
n_trials: n_trials,
|
58
|
+
objective: objective,
|
59
|
+
hyperparameter_ranges: config,
|
60
|
+
},
|
61
|
+
direction: direction,
|
62
|
+
status: :running,
|
63
|
+
started_at: Time.current,
|
64
|
+
wandb_url: wandb_enabled? ? "https://wandb.ai/fundera/#{@project_name}" : nil,
|
65
|
+
}.compact
|
66
|
+
|
67
|
+
tuner_job = EasyML::TunerJob.create!(tuner_params)
|
68
|
+
@job = tuner_job
|
69
|
+
@study = Optuna::Study.new(direction: direction)
|
48
70
|
@results = []
|
71
|
+
model.evaluator = evaluator if evaluator.present?
|
49
72
|
model.task = task
|
73
|
+
|
74
|
+
model.dataset.refresh
|
50
75
|
x_true, y_true = model.dataset.test(split_ys: true)
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
adapter.
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
@
|
69
|
-
|
70
|
-
|
71
|
-
|
76
|
+
self.x_true = x_true
|
77
|
+
self.y_true = y_true
|
78
|
+
adapter.tune_started_at = tune_started_at
|
79
|
+
adapter.y_true = y_true
|
80
|
+
adapter.x_true = x_true
|
81
|
+
|
82
|
+
model.prepare_data unless model.batch_mode
|
83
|
+
model.prepare_callbacks(self)
|
84
|
+
|
85
|
+
n_trials.times do |run_number|
|
86
|
+
trial = @study.ask
|
87
|
+
puts "Running trial #{trial.number}"
|
88
|
+
@tuner_run = tuner_job.tuner_runs.new(
|
89
|
+
trial_number: trial.number,
|
90
|
+
status: :running,
|
91
|
+
)
|
92
|
+
|
93
|
+
self.current_run = @tuner_run
|
94
|
+
|
95
|
+
begin
|
96
|
+
run_metrics = tune_once(trial, x_true, y_true, adapter, &progress_block)
|
97
|
+
result = calculate_result(run_metrics)
|
98
|
+
@results.push(result)
|
99
|
+
|
100
|
+
params = {
|
101
|
+
hyperparameters: model.hyperparameters.to_h,
|
102
|
+
value: result,
|
103
|
+
status: :success,
|
104
|
+
}.compact
|
105
|
+
|
106
|
+
@tuner_run.update!(params)
|
107
|
+
@study.tell(trial, result)
|
108
|
+
rescue StandardError => e
|
109
|
+
@tuner_run.update!(status: :failed, hyperparameters: {})
|
110
|
+
puts "Optuna failed with: #{e.message}"
|
111
|
+
raise e
|
112
|
+
end
|
72
113
|
end
|
73
114
|
|
74
|
-
|
75
|
-
|
76
|
-
|
115
|
+
model.after_tuning
|
116
|
+
return nil if tuner_job.tuner_runs.all?(&:failed?)
|
117
|
+
|
118
|
+
best_run = tuner_job.best_run
|
119
|
+
tuner_job.update!(
|
120
|
+
metadata: adapter.metadata,
|
121
|
+
best_tuner_run_id: best_run&.id,
|
122
|
+
status: :success,
|
123
|
+
completed_at: Time.current,
|
124
|
+
)
|
125
|
+
|
126
|
+
best_run&.hyperparameters
|
127
|
+
rescue StandardError => e
|
128
|
+
tuner_job&.update!(status: :failed, completed_at: Time.current)
|
129
|
+
raise e
|
77
130
|
end
|
78
131
|
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
132
|
+
private
|
133
|
+
|
134
|
+
def calculate_result(run_metrics)
|
135
|
+
run_metrics.symbolize_keys!
|
136
|
+
|
137
|
+
if model.evaluator.present?
|
138
|
+
run_metrics[model.evaluator[:metric].to_sym]
|
139
|
+
else
|
140
|
+
run_metrics[objective.to_sym]
|
83
141
|
end
|
84
142
|
end
|
85
143
|
|
86
|
-
def tune_once(trial, x_true, y_true, adapter)
|
144
|
+
def tune_once(trial, x_true, y_true, adapter, &progress_block)
|
87
145
|
adapter.run_trial(trial) do |model|
|
88
|
-
|
146
|
+
model.fit(tuning: true, &progress_block)
|
147
|
+
y_pred = model.predict(x_true)
|
89
148
|
model.metrics = metrics
|
90
|
-
model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
149
|
+
metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
150
|
+
puts metrics
|
151
|
+
metrics
|
91
152
|
end
|
92
153
|
end
|
93
154
|
|
@@ -98,7 +159,7 @@ module EasyML
|
|
98
159
|
end
|
99
160
|
raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless objective.present?
|
100
161
|
|
101
|
-
self.metrics = EasyML::
|
162
|
+
self.metrics = EasyML::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
|
102
163
|
end
|
103
164
|
end
|
104
165
|
end
|
data/lib/easy_ml/core.rb
CHANGED
@@ -0,0 +1,24 @@
|
|
1
|
+
module EasyML
  module CoreExt
    # Core extension mixed into ::Hash (see the include at the bottom of this file).
    module Hash
      # Decides whether an already-compacted value should be dropped.
      #
      # nil, whitespace-only strings and empty collections are blank. `false` is
      # deliberately NOT blank, so explicit boolean settings survive compaction.
      # Defined on the module itself so it is not added to every Hash instance.
      #
      # @param value [Object] the value to inspect
      # @return [Boolean] true when the value should be removed
      def self.blank_value?(value)
        case value
        when nil then true
        when ::String then value.match?(/\A[[:space:]]*\z/)
        else value.respond_to?(:empty?) ? !!value.empty? : false
        end
      end

      # Returns a new hash without nil or blank entries, recursing into nested
      # hashes. The receiver is not modified.
      #
      # * nil values are removed.
      # * Nested hashes are compacted recursively and removed if they end up empty.
      # * Arrays have their nil elements removed and their hash elements
      #   compacted; nested arrays are left as they are.
      # * Blank strings and empty collections are removed.
      # * `false` is kept.
      #
      # @return [::Hash] a new, compacted hash
      def deep_compact
        each_with_object({}) do |(key, value), result|
          next if value.nil?

          # ::Hash / ::Array are written with a leading :: because a bare `Hash`
          # inside this module would resolve to EasyML::CoreExt::Hash.
          compacted = case value
                      when ::Hash then value.deep_compact
                      when ::Array then value.map { |v| v.is_a?(::Hash) ? v.deep_compact : v }.compact
                      else value
                      end

          result[key] = compacted unless EasyML::CoreExt::Hash.blank_value?(compacted)
        end
      end
    end
  end
end

# Extend Hash class with our custom method
Hash.include EasyML::CoreExt::Hash
|
@@ -1,9 +1,15 @@
|
|
1
1
|
require "pathname"
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
3
|
+
module EasyML
  module CoreExt
    # Core extension mixed into ::Pathname (see the include at the bottom of this file).
    module Pathname
      # Returns the normalized path, extended by +folder+ unless it already
      # ends in +folder+.
      #
      # The comparison uses the basename of the *cleaned* path, so inputs such
      # as "a/b/." or "a/b/c/.." are recognised as already ending in "b".
      # +folder+ is converted with #to_s, so Strings, Symbols and Pathnames are
      # all accepted.
      #
      # @param folder [#to_s] the directory name the result should end with
      # @return [::Pathname] the cleaned path ending in +folder+
      def append(folder)
        dir = cleanpath
        name = folder.to_s
        dir.basename.to_s == name ? dir : dir.join(name)
      end
    end
  end
end

Pathname.include EasyML::CoreExt::Pathname
|
@@ -0,0 +1,90 @@
|
|
1
|
+
module EasyML
  module Data
    # Detects date-like string columns and converts them to Polars datetimes.
    module DateConverter
      # strptime patterns tried in order; the first pattern that parses every
      # sampled value wins.
      # NOTE(review): confirm that Ruby's DateTime.strptime accepts the
      # width-qualified "%6N" directive used by the first pattern.
      COMMON_DATE_FORMATS = [
        "%Y-%m-%dT%H:%M:%S.%6N", # e.g., "2021-01-01T00:00:00.000000"
        "%Y-%m-%d %H:%M:%S.%L",  # e.g., "2021-01-01 00:01:36.000"
        "%Y-%m-%d %H:%M:%S",     # e.g., "2021-01-01 00:01:36"
        "%Y-%m-%d %H:%M",        # e.g., "2021-01-01 00:01"
        "%Y-%m-%d",              # e.g., "2021-01-01"
        "%m/%d/%Y %H:%M:%S",     # e.g., "01/01/2021 00:01:36"
        "%m/%d/%Y",              # e.g., "01/01/2021"
        "%d-%m-%Y",              # e.g., "01-01-2021"
        "%d-%b-%Y %H:%M:%S",     # e.g., "01-Jan-2021 00:01:36"
        "%d-%b-%Y",              # e.g., "01-Jan-2021"
        "%b %d, %Y",             # e.g., "Jan 01, 2021"
        "%Y/%m/%d %H:%M:%S",     # e.g., "2021/01/01 00:01:36"
        "%Y/%m/%d",              # e.g., "2021/01/01"
      ].freeze

      # Substring replacements applied to a Ruby format string, keyed by
      # conversion name.
      FORMAT_MAPPINGS = {
        ruby_to_polars: {
          "%L" => "%3f",  # milliseconds
          "%6N" => "%6f", # microseconds
          "%N" => "%9f",  # nanoseconds
        }.freeze,
      }.freeze

      class << self
        # Attempts to convert a string column to datetime if it appears to be a date.
        #
        # When +column+ is omitted, +df+ is treated as a series: its name is
        # used as the column and it is wrapped in a new dataframe.
        #
        # @param df [Polars::DataFrame, Polars::Series] the dataframe holding the
        #   column, or the series itself when +column+ is nil
        # @param column [String, Symbol, nil] the name of the column to convert
        # @return [Polars::DataFrame] the dataframe with the column converted, or
        #   unchanged when the column is already a datetime, is not a string
        #   column, or matches no known format
        def maybe_convert_date(df, column = nil)
          if column.nil?
            series = df
            column = series.name
            df = Polars::DataFrame.new(series)
          else
            series = df[column]
          end
          return df if series.dtype.is_a?(Polars::Datetime)
          return df unless series.dtype == Polars::Utf8

          format = detect_polars_format(series)
          return df unless format

          df.with_column(
            Polars.col(column.to_s).str.strptime(Polars::Datetime, format).alias(column.to_s)
          )
        end

        private

        # Detects the date format of a series and translates it with the
        # :ruby_to_polars mapping.
        #
        # @param series [Polars::Series] the series to inspect
        # @return [String, nil] the translated format, or nil if none matched
        def detect_polars_format(series)
          return nil unless series.is_a?(Polars::Series)

          sample = series.filter(series.is_not_null).head(100).to_a
          ruby_format = detect_date_format(sample)
          convert_format(:ruby_to_polars, ruby_format)
        end

        # Finds the first known format that parses every sampled value.
        #
        # nil entries are ignored; when nothing is left after removing them the
        # result is nil (an empty sample must not count as a match).
        #
        # @param date_strings [Array<String, nil>] candidate values
        # @return [String, nil] the matching strptime format, or nil
        def detect_date_format(date_strings)
          candidates = date_strings.compact
          return nil if candidates.empty?

          sample = candidates.sample([100, candidates.length].min)
          COMMON_DATE_FORMATS.find do |format|
            sample.all? { |date_str| parseable?(date_str, format) }
          end
        end

        # @param date_str [String] the value to parse
        # @param format [String] the strptime format to try
        # @return [Boolean] whether DateTime.strptime accepts the value
        def parseable?(date_str, format)
          DateTime.strptime(date_str, format)
          true
        rescue StandardError
          false
        end

        # Applies every replacement registered for +conversion+ to +format+.
        #
        # @param conversion [Symbol] a key of FORMAT_MAPPINGS
        # @param format [String, nil] the format to translate
        # @return [String, nil] the translated copy, or nil when +format+ is nil
        def convert_format(conversion, format)
          return nil if format.nil?

          FORMAT_MAPPINGS[conversion].reduce(format.dup) do |converted, (from, to)|
            converted.gsub(from, to)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module EasyML
  module Data
    # String-based helpers for inspecting filter expressions. Both methods work
    # on the receiver's #to_s output.
    module FilterExtensions
      # Operator fragments whose presence marks an expression as a comparison.
      PRIMARY_KEY_OPERATORS = [">", "<", ">=", "<=", "=", "eq", "gt", "lt", "ge", "le"].freeze

      # Reports whether the expression looks like a comparison on the primary key.
      #
      # This is a textual heuristic: the expression string must mention the
      # first primary-key column and contain at least one operator fragment.
      #
      # @param primary_key [String, Symbol, Array<String, Symbol>, nil] the
      #   primary-key column, or a list of which only the first entry is checked
      # @return [Boolean]
      def is_primary_key_filter?(primary_key)
        return false unless primary_key

        key_name = (primary_key.is_a?(Array) ? primary_key : [primary_key]).first
        return false if key_name.nil? # empty key list

        # Filter expressions in Polars are represented as strings like:
        # [([(col("LOAN_APP_ID")) > (dyn int: 4)]) & ([(col("LOAN_APP_ID")) < (dyn int: 16)])]
        expr_str = to_s
        # to_s on the key so Symbol keys do not raise in String#include?
        return false unless expr_str.include?(key_name.to_s)

        PRIMARY_KEY_OPERATORS.any? { |op| expr_str.include?(op) }
      end

      # Extracts the numeric literals that appear in the expression string.
      #
      # Matches both integers and floats; every value is returned as a Float.
      #
      # @return [Array<Float>] the distinct values in order of first appearance
      def extract_primary_key_values
        to_s.scan(/(?:dyn int|float): (-?\d+(?:\.\d+)?)/).flatten.map(&:to_f).uniq
      end
    end
  end
end
|
27
|
+
|
28
|
+
# Mix the filter helpers into Polars expressions.
Polars::Expr.include(EasyML::Data::FilterExtensions)
|
@@ -0,0 +1,126 @@
|
|
1
|
+
require_relative "date_converter"
|
2
|
+
|
3
|
+
module EasyML
  module Data
    # Maps between Polars dtypes and symbolic column types, and infers the
    # symbolic type of a series from its dtype and contents.
    module PolarsColumn
      # Symbolic type => Polars dtype class.
      TYPE_MAP = {
        float: Polars::Float64,
        integer: Polars::Int64,
        boolean: Polars::Boolean,
        datetime: Polars::Datetime,
        string: Polars::String,
        text: Polars::String,
        categorical: Polars::Categorical,
      }.freeze

      # Polars dtype class name => symbolic type.
      # NOTE: :string and :text share Polars::String, so after inversion only
      # the later entry (:text) remains for that dtype.
      POLARS_MAP = TYPE_MAP.invert.stringify_keys.freeze

      # Symbolic types treated as numeric. determine_type reports :float and
      # :integer; :numeric is kept so existing callers passing it still match.
      NUMERIC_TYPES = %i[numeric float integer].freeze

      class << self
        # Looks up the symbolic type by the class name of +polars_type+.
        #
        # @param polars_type [Object] a dtype instance
        # @return [Symbol, nil] the symbolic type, or nil when unmapped
        def polars_to_sym(polars_type)
          POLARS_MAP[polars_type.class.to_s]
        end

        # @param symbol [Symbol] a symbolic type
        # @return [Class, nil] the mapped Polars dtype class, or nil when unmapped
        def sym_to_polars(symbol)
          TYPE_MAP[symbol]
        end

        # Determines the semantic type of a field based on its dtype and data.
        #
        # String columns are inspected: date-like ones are reported as
        # datetimes, the rest as :text or :categorical. Dtypes without an
        # explicit branch below fall back to :categorical.
        #
        # @param series [Polars::Series] The series to analyze
        # @param polars_type [Boolean] when true, return the Polars dtype
        #   instead of the symbolic type
        # @return [Symbol, Object] one of :float, :integer, :datetime, :boolean,
        #   :text or :categorical — or the corresponding Polars dtype when
        #   +polars_type+ is true
        def determine_type(series, polars_type = false)
          dtype = series.dtype

          # Classify string columns once and reuse the result below.
          string_type = nil
          if dtype.is_a?(Polars::Utf8)
            string_type = determine_string_type(series)
            if string_type == :datetime
              date = EasyML::Data::DateConverter.maybe_convert_date(series)
              return polars_type ? date[date.columns.first].dtype : :datetime
            end
          end

          type_name = case dtype
                      when Polars::Float64 then :float
                      when Polars::Int64 then :integer
                      when Polars::Datetime then :datetime
                      when Polars::Boolean then :boolean
                      when Polars::Utf8 then string_type || determine_string_type(series)
                      else :categorical
                      end

          polars_type ? sym_to_polars(type_name) : type_name
        end

        # Determines if a string field is a date, text, or categorical.
        #
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] One of :datetime, :text, or :categorical
        def determine_string_type(series)
          converted = EasyML::Data::DateConverter.maybe_convert_date(
            Polars::DataFrame.new({ temp: series }), :temp
          )
          return :datetime if converted[:temp].dtype.is_a?(Polars::Datetime)

          categorical_or_text?(series)
        end

        # Determines if a string field is categorical or free text.
        #
        # The series is :text when its most frequent value covers less than 10%
        # of the non-null rows, or when the average share per distinct value
        # (100 / n_unique) is below 1%. Otherwise, and for all-null series, it
        # is :categorical.
        #
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] Either :categorical or :text
        def categorical_or_text?(series)
          non_null_count = series.len - series.null_count
          return :categorical if non_null_count == 0

          # Share of the non-null rows taken by each distinct value, in percent
          value_counts = series.value_counts(parallel: true)
          percentages = value_counts.with_column(
            (value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
          )

          # No dominant value at all: treat as free text
          max_percentage = percentages["percentage"].max
          return :text if max_percentage < 10.0

          # If the average category represents less than 1% of data, it's likely text
          avg_percentage = 100.0 / series.n_unique
          avg_percentage < 1.0 ? :text : :categorical
        end

        # Returns whether the field type is numeric.
        #
        # @param field_type [Symbol] The field type to check
        # @return [Boolean] true for :float, :integer and :numeric
        def numeric?(field_type)
          NUMERIC_TYPES.include?(field_type)
        end

        # Returns whether the field type is categorical.
        #
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def categorical?(field_type)
          field_type == :categorical
        end

        # Returns whether the field type is datetime.
        #
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def datetime?(field_type)
          field_type == :datetime
        end

        # Returns whether the field type is text.
        #
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def text?(field_type)
          field_type == :text
        end
      end
    end
  end
end
|