easy_ml 0.1.4 → 0.2.0.pre.rc1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +234 -26
- data/Rakefile +45 -0
- data/app/controllers/easy_ml/application_controller.rb +67 -0
- data/app/controllers/easy_ml/columns_controller.rb +38 -0
- data/app/controllers/easy_ml/datasets_controller.rb +156 -0
- data/app/controllers/easy_ml/datasources_controller.rb +88 -0
- data/app/controllers/easy_ml/deploys_controller.rb +20 -0
- data/app/controllers/easy_ml/models_controller.rb +151 -0
- data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
- data/app/controllers/easy_ml/settings_controller.rb +59 -0
- data/app/frontend/components/AlertProvider.tsx +108 -0
- data/app/frontend/components/DatasetPreview.tsx +161 -0
- data/app/frontend/components/EmptyState.tsx +28 -0
- data/app/frontend/components/ModelCard.tsx +255 -0
- data/app/frontend/components/ModelDetails.tsx +334 -0
- data/app/frontend/components/ModelForm.tsx +384 -0
- data/app/frontend/components/Navigation.tsx +300 -0
- data/app/frontend/components/Pagination.tsx +72 -0
- data/app/frontend/components/Popover.tsx +55 -0
- data/app/frontend/components/PredictionStream.tsx +105 -0
- data/app/frontend/components/ScheduleModal.tsx +726 -0
- data/app/frontend/components/SearchInput.tsx +23 -0
- data/app/frontend/components/SearchableSelect.tsx +132 -0
- data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
- data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
- data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
- data/app/frontend/components/dataset/ColumnList.tsx +101 -0
- data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
- data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
- data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
- data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
- data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
- data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
- data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
- data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
- data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
- data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
- data/app/frontend/components/dataset/splitters/constants.ts +77 -0
- data/app/frontend/components/dataset/splitters/types.ts +168 -0
- data/app/frontend/components/dataset/splitters/utils.ts +53 -0
- data/app/frontend/components/features/CodeEditor.tsx +46 -0
- data/app/frontend/components/features/DataPreview.tsx +150 -0
- data/app/frontend/components/features/FeatureCard.tsx +88 -0
- data/app/frontend/components/features/FeatureForm.tsx +235 -0
- data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
- data/app/frontend/components/settings/PluginSettings.tsx +81 -0
- data/app/frontend/components/ui/badge.tsx +44 -0
- data/app/frontend/components/ui/collapsible.tsx +9 -0
- data/app/frontend/components/ui/scroll-area.tsx +46 -0
- data/app/frontend/components/ui/separator.tsx +29 -0
- data/app/frontend/entrypoints/App.tsx +40 -0
- data/app/frontend/entrypoints/Application.tsx +24 -0
- data/app/frontend/hooks/useAutosave.ts +61 -0
- data/app/frontend/layouts/Layout.tsx +38 -0
- data/app/frontend/lib/utils.ts +6 -0
- data/app/frontend/mockData.ts +272 -0
- data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
- data/app/frontend/pages/DatasetsPage.tsx +261 -0
- data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
- data/app/frontend/pages/DatasourcesPage.tsx +261 -0
- data/app/frontend/pages/EditModelPage.tsx +45 -0
- data/app/frontend/pages/EditTransformationPage.tsx +56 -0
- data/app/frontend/pages/ModelsPage.tsx +115 -0
- data/app/frontend/pages/NewDatasetPage.tsx +366 -0
- data/app/frontend/pages/NewModelPage.tsx +45 -0
- data/app/frontend/pages/NewTransformationPage.tsx +43 -0
- data/app/frontend/pages/SettingsPage.tsx +272 -0
- data/app/frontend/pages/ShowModelPage.tsx +30 -0
- data/app/frontend/pages/TransformationsPage.tsx +95 -0
- data/app/frontend/styles/application.css +100 -0
- data/app/frontend/types/dataset.ts +146 -0
- data/app/frontend/types/datasource.ts +33 -0
- data/app/frontend/types/preprocessing.ts +1 -0
- data/app/frontend/types.ts +113 -0
- data/app/helpers/easy_ml/application_helper.rb +10 -0
- data/app/jobs/easy_ml/application_job.rb +21 -0
- data/app/jobs/easy_ml/batch_job.rb +46 -0
- data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
- data/app/jobs/easy_ml/deploy_job.rb +13 -0
- data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
- data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
- data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
- data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
- data/app/jobs/easy_ml/training_job.rb +62 -0
- data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
- data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
- data/app/models/easy_ml/cleaner.rb +82 -0
- data/app/models/easy_ml/column.rb +124 -0
- data/app/models/easy_ml/column_history.rb +30 -0
- data/app/models/easy_ml/column_list.rb +122 -0
- data/app/models/easy_ml/concerns/configurable.rb +61 -0
- data/app/models/easy_ml/concerns/versionable.rb +19 -0
- data/app/models/easy_ml/dataset.rb +767 -0
- data/app/models/easy_ml/dataset_history.rb +56 -0
- data/app/models/easy_ml/datasource.rb +182 -0
- data/app/models/easy_ml/datasource_history.rb +24 -0
- data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
- data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
- data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
- data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
- data/app/models/easy_ml/deploy.rb +114 -0
- data/app/models/easy_ml/event.rb +79 -0
- data/app/models/easy_ml/feature.rb +437 -0
- data/app/models/easy_ml/feature_history.rb +38 -0
- data/app/models/easy_ml/model.rb +575 -41
- data/app/models/easy_ml/model_file.rb +133 -0
- data/app/models/easy_ml/model_file_history.rb +24 -0
- data/app/models/easy_ml/model_history.rb +51 -0
- data/app/models/easy_ml/models/base_model.rb +58 -0
- data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
- data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
- data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
- data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
- data/app/models/easy_ml/models/xgboost.rb +544 -5
- data/app/models/easy_ml/prediction.rb +44 -0
- data/app/models/easy_ml/retraining_job.rb +278 -0
- data/app/models/easy_ml/retraining_run.rb +184 -0
- data/app/models/easy_ml/settings.rb +37 -0
- data/app/models/easy_ml/splitter.rb +90 -0
- data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
- data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
- data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
- data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
- data/app/models/easy_ml/tuner_job.rb +56 -0
- data/app/models/easy_ml/tuner_run.rb +31 -0
- data/app/models/splitter_history.rb +6 -0
- data/app/serializers/easy_ml/column_serializer.rb +27 -0
- data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
- data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
- data/app/serializers/easy_ml/feature_serializer.rb +27 -0
- data/app/serializers/easy_ml/model_serializer.rb +90 -0
- data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
- data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
- data/app/serializers/easy_ml/settings_serializer.rb +9 -0
- data/app/views/layouts/easy_ml/application.html.erb +15 -0
- data/config/initializers/resque.rb +3 -0
- data/config/resque-pool.yml +6 -0
- data/config/routes.rb +39 -0
- data/config/spring.rb +1 -0
- data/config/vite.json +15 -0
- data/lib/easy_ml/configuration.rb +64 -0
- data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
- data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
- data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
- data/lib/easy_ml/core/model_evaluator.rb +161 -89
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
- data/lib/easy_ml/core/tuner.rb +123 -62
- data/lib/easy_ml/core.rb +0 -3
- data/lib/easy_ml/core_ext/hash.rb +24 -0
- data/lib/easy_ml/core_ext/pathname.rb +11 -5
- data/lib/easy_ml/data/date_converter.rb +90 -0
- data/lib/easy_ml/data/filter_extensions.rb +31 -0
- data/lib/easy_ml/data/polars_column.rb +126 -0
- data/lib/easy_ml/data/polars_reader.rb +297 -0
- data/lib/easy_ml/data/preprocessor.rb +280 -142
- data/lib/easy_ml/data/simple_imputer.rb +255 -0
- data/lib/easy_ml/data/splits/file_split.rb +252 -0
- data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
- data/lib/easy_ml/data/splits/split.rb +95 -0
- data/lib/easy_ml/data/splits.rb +9 -0
- data/lib/easy_ml/data/statistics_learner.rb +93 -0
- data/lib/easy_ml/data/synced_directory.rb +341 -0
- data/lib/easy_ml/data.rb +6 -2
- data/lib/easy_ml/engine.rb +105 -6
- data/lib/easy_ml/feature_store.rb +227 -0
- data/lib/easy_ml/features.rb +61 -0
- data/lib/easy_ml/initializers/inflections.rb +17 -3
- data/lib/easy_ml/logging.rb +2 -2
- data/lib/easy_ml/predict.rb +74 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
- data/lib/easy_ml/support/est.rb +5 -1
- data/lib/easy_ml/support/file_rotate.rb +79 -15
- data/lib/easy_ml/support/file_support.rb +9 -0
- data/lib/easy_ml/support/local_file.rb +24 -0
- data/lib/easy_ml/support/lockable.rb +62 -0
- data/lib/easy_ml/support/synced_file.rb +103 -0
- data/lib/easy_ml/support/utc.rb +5 -1
- data/lib/easy_ml/support.rb +6 -3
- data/lib/easy_ml/version.rb +4 -1
- data/lib/easy_ml.rb +7 -2
- metadata +355 -72
- data/app/models/easy_ml/models.rb +0 -5
- data/lib/easy_ml/core/model.rb +0 -30
- data/lib/easy_ml/core/model_core.rb +0 -181
- data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
- data/lib/easy_ml/core/models/xgboost.rb +0 -10
- data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
- data/lib/easy_ml/core/models.rb +0 -10
- data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
- data/lib/easy_ml/core/uploaders.rb +0 -7
- data/lib/easy_ml/data/dataloader.rb +0 -6
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
- data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
- data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
- data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
- data/lib/easy_ml/data/dataset/splits.rb +0 -11
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
- data/lib/easy_ml/data/dataset/splitters.rb +0 -9
- data/lib/easy_ml/data/dataset.rb +0 -430
- data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
- data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
- data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
- data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
- data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
- data/lib/easy_ml/data/datasource.rb +0 -33
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
- data/lib/easy_ml/deployment.rb +0 -5
- data/lib/easy_ml/support/synced_directory.rb +0 -134
- data/lib/easy_ml/transforms.rb +0 -29
- /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
data/config/vite.json
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
{
|
2
|
+
"all": {
|
3
|
+
"sourceCodeDir": "app/frontend",
|
4
|
+
"watchAdditionalPaths": [],
|
5
|
+
"publicOutputDir": "easy-ml"
|
6
|
+
},
|
7
|
+
"development": {
|
8
|
+
"autoBuild": true,
|
9
|
+
"port": 3037
|
10
|
+
},
|
11
|
+
"test": {
|
12
|
+
"autoBuild": true,
|
13
|
+
"publicOutputDir": "vite-test"
|
14
|
+
}
|
15
|
+
}
|
@@ -0,0 +1,64 @@
|
|
1
|
+
require "singleton"
|
2
|
+
require_relative "../../app/models/easy_ml/settings"
|
3
|
+
|
4
|
+
module EasyML
  # Process-wide accessor for EasyML settings.
  #
  # For every attribute name returned by EasyML::Settings.configuration_attributes
  # this class defines an instance reader and writer. Both delegate to a single
  # EasyML::Settings record, which is fetched with +first_or_create+ and memoized
  # on the singleton instance. The readers are also exposed at class level, e.g.
  # EasyML::Configuration.timezone.
  #
  # Attributes that appear in LABELER additionally get a "<key>_label" reader
  # that maps the stored value to its human-readable label.
  class Configuration
    include Singleton

    # Timezones offered for the +timezone+ setting, with display labels.
    TIMEZONES = [
      { value: "America/New_York", label: "Eastern Time" },
      { value: "America/Chicago", label: "Central Time" },
      { value: "America/Denver", label: "Mountain Time" },
      { value: "America/Los_Angeles", label: "Pacific Time" },
    ].freeze

    KEYS = EasyML::Settings.configuration_attributes

    # Setting name (Symbol) => list of { value:, label: } options.
    LABELER = {
      timezone: TIMEZONES,
    }.freeze

    KEYS.each do |key|
      define_method "#{key}=" do |value|
        db_settings.send("#{key}=", value)
      end

      define_method key do
        db_settings.send(key)
      end

      next unless LABELER.key?(key.to_sym)

      # Look the options up with the same Symbol form used by the guard above.
      # KEYS may hold Strings, and LABELER[key] would then be nil.
      options = LABELER[key.to_sym]

      # Returns the label for the current value, or nil when the stored value
      # is not one of the known options (instead of raising NoMethodError).
      define_method "#{key}_label" do
        options.find { |option| option[:value] == send(key) }&.dig(:label)
      end
    end

    class << self
      # Yields the singleton instance so settings can be assigned, then persists
      # the backing EasyML::Settings record.
      def configure
        yield instance
        instance.db_settings.save
      end

      KEYS.each do |key|
        define_method key do
          instance.send(key)
        end

        next unless LABELER.key?(key.to_sym)

        define_method "#{key}_label" do
          instance.send("#{key}_label")
        end
      end

      private

      def db_settings
        instance.db_settings
      end
    end

    # The persisted settings row backing this configuration. It is created on
    # first access and memoized for the life of the singleton.
    def db_settings
      @db_settings ||= EasyML::Settings.first_or_create
    end
  end
end
|
@@ -0,0 +1,53 @@
|
|
1
|
+
module EasyML
  module Core
    module Evaluators
      # Mixin shared by metric evaluators.
      #
      # Including classes implement #evaluate and may override #direction. The
      # metric key is derived from the unqualified class name (e.g. F1Score ->
      # "f1_score"). #label and #to_h build on that key.
      module BaseEvaluator
        def self.included(base)
          base.extend(ClassMethods)
        end

        # Optimization direction; "minimize" unless the evaluator overrides it.
        def direction
          "minimize"
        end

        # Human-readable metric name built from #key.
        # Relies on ActiveSupport's String#titleize.
        def label
          key.split("_").join(" ").titleize
        end

        # NOTE(review): EasyML::Option is defined outside this file; confirm it
        # accepts the hash produced by #to_h.
        def to_option
          EasyML::Option.new(to_h)
        end

        # Hash describing this evaluator, e.g. for a frontend select box.
        def to_h
          {
            value: key,
            label: label,
            direction: direction,
          }
        end

        # Snake-case metric key taken from the last segment of the class name.
        # Relies on ActiveSupport's String#underscore.
        def key
          self.class.name.split("::").last.underscore
        end

        # Instance method that evaluators must implement.
        #
        # @raise [NotImplementedError] unless overridden by the including class
        def evaluate(y_pred: nil, y_true: nil, x_true: nil)
          raise NotImplementedError, "#{self.class} must implement #evaluate"
        end

        # Looks up this evaluator's value in a metrics hash whose keys may be
        # Strings or Symbols.
        #
        # The caller's hash is left untouched; the previous version mutated it
        # with symbolize_keys!.
        # NOTE(review): requires the including class to define #metric; nothing
        # in this module provides it.
        #
        # @param metrics [Hash] metric name => value
        # @return [Object, nil] the value for #metric, or nil if absent
        def calculate_result(metrics)
          symbolized = metrics.transform_keys { |name| name.respond_to?(:to_sym) ? name.to_sym : name }
          symbolized[metric.to_sym]
        end

        module ClassMethods
          # Gives each including class its own +registry+ accessor on the class.
          def self.extended(base)
            class << base
              attr_accessor :registry
            end
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,126 @@
|
|
1
|
+
module EasyML
  module Core
    module Evaluators
      # Binary-classification metrics.
      #
      # Inputs are cast to Numo arrays; the positive class is the label 1.
      # +x_true+ is accepted for interface compatibility and ignored.
      module ClassificationEvaluators
        # Fraction of predictions equal to the true label.
        class AccuracyScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            y_pred.eq(y_true).count_true.to_f / y_pred.size
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FP); 0 when nothing is predicted positive.
        class PrecisionScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            predicted_positives = y_pred.eq(1).count_true
            return 0 if predicted_positives.zero?

            true_positives.to_f / predicted_positives
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FN); 0 when y_true contains no positives.
        class RecallScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            actual_positives = y_true.eq(1).count_true
            # Mirror PrecisionScore. Dividing 0.0 by zero here would yield NaN.
            return 0 if actual_positives.zero?

            true_positives.to_f / actual_positives
          end

          def direction
            "maximize"
          end
        end

        # Harmonic mean of precision and recall; 0 when both are 0.
        class F1Score
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            precision = PrecisionScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            recall = RecallScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            return 0 unless (precision + recall) > 0

            2 * (precision * recall) / (precision + recall)
          end

          def direction
            "maximize"
          end
        end

        # Area under the ROC curve, computed with the trapezoidal rule.
        #
        # y_pred holds scores (higher means "more likely positive"). Returns
        # Float::NAN when y_true lacks either class, since ROC is undefined there.
        class AUC
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            scores = Numo::DFloat.cast(y_pred).to_a
            labels = Numo::Int32.cast(y_true).to_a

            positive_count = labels.count(1)
            negative_count = labels.size - positive_count
            return Float::NAN if positive_count.zero? || negative_count.zero?

            tp = 0
            fp = 0
            prev_tpr = 0.0
            prev_fpr = 0.0
            auc = 0.0

            # Sweep the threshold from the highest score down. The old ascending
            # sort scored a perfect classifier as 0. Tied scores form a single ROC
            # point, so the result does not depend on how ties happen to be ordered.
            scores.zip(labels)
                  .sort_by { |score, _label| -score }
                  .chunk_while { |a, b| a[0] == b[0] }
                  .each do |group|
              group.each do |_score, label|
                if label == 1
                  tp += 1
                else
                  fp += 1
                end
              end

              tpr = tp.to_f / positive_count
              fpr = fp.to_f / negative_count
              auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0
              prev_tpr = tpr
              prev_fpr = fpr
            end

            auc
          end

          def direction
            "maximize"
          end
        end

        # Alias of AUC under the conventional "ROC AUC" name.
        class ROC_AUC
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            AUC.new.evaluate(y_pred: y_pred, y_true: y_true)
          end

          def direction
            "maximize"
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module EasyML
  module Core
    module Evaluators
      # Regression metrics.
      #
      # Each evaluator casts its inputs to Numo::DFloat and returns a scalar.
      # +x_true+ is part of the shared evaluator interface and is unused here.
      module RegressionEvaluators
        # Mean of |y_pred - y_true|; lower is better.
        class MeanAbsoluteError
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            residuals = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            residuals.abs.mean
          end

          def direction
            "minimize"
          end
        end

        # Mean of squared residuals; lower is better.
        class MeanSquaredError
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            residuals = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            (residuals**2).mean
          end

          def direction
            "minimize"
          end
        end

        # Square root of the mean squared error; lower is better.
        class RootMeanSquaredError
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            mse = MeanSquaredError.new.evaluate(y_pred: y_pred, y_true: y_true)
            Math.sqrt(mse)
          end

          def direction
            "minimize"
          end
        end

        # Coefficient of determination: 1 - SS_res / SS_tot; higher is better.
        class R2Score
          include BaseEvaluator

          def direction
            "maximize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            actual = Numo::DFloat.cast(y_true)
            predicted = Numo::DFloat.cast(y_pred)

            total_sum_squares = ((actual - actual.mean)**2).sum
            residual_sum_squares = ((actual - predicted)**2).sum

            # A constant target leaves R² undefined unless the fit is also exact.
            return (residual_sum_squares.zero? ? 1.0 : Float::NAN) if total_sum_squares.zero?

            1 - (residual_sum_squares / total_sum_squares)
          end
        end
      end
    end
  end
end
|
@@ -1,78 +1,86 @@
|
|
1
|
+
require "numo/narray"
|
2
|
+
require_relative "evaluators/base_evaluator"
|
3
|
+
require_relative "evaluators/regression_evaluators"
|
4
|
+
require_relative "evaluators/classification_evaluators"
|
5
|
+
|
1
6
|
module EasyML
|
2
7
|
module Core
|
3
8
|
class ModelEvaluator
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
9
|
+
class << self
|
10
|
+
def callbacks=(callback)
|
11
|
+
@callbacks ||= []
|
12
|
+
@callbacks.push(callback)
|
13
|
+
end
|
14
|
+
|
15
|
+
def callbacks
|
16
|
+
@callbacks || []
|
17
|
+
end
|
18
|
+
|
19
|
+
def register(metric_name, evaluator, type, aliases = {})
|
20
|
+
@registry ||= {}
|
21
|
+
unless evaluator.included_modules.include?(Evaluators::BaseEvaluator)
|
22
|
+
evaluator.include(Evaluators::BaseEvaluator)
|
23
|
+
end
|
24
|
+
|
25
|
+
callbacks.each do |callback|
|
26
|
+
callback.call(metric_name)
|
27
|
+
end
|
28
|
+
|
29
|
+
@registry[metric_name.to_sym] = {
|
30
|
+
evaluator: evaluator,
|
31
|
+
type: type,
|
32
|
+
aliases: (aliases || []).map(&:to_sym),
|
33
|
+
}
|
34
|
+
end
|
35
|
+
|
36
|
+
def get(name)
|
37
|
+
return if name.nil?
|
38
|
+
|
39
|
+
@registry ||= {}
|
40
|
+
option = (@registry[name.to_sym] || @registry.detect do |_k, opts|
|
41
|
+
opts[:aliases].include?(name.to_sym)
|
42
|
+
end.last) || {}
|
43
|
+
option.dig(:evaluator)
|
44
|
+
end
|
45
|
+
|
46
|
+
def for_frontend(evaluator)
|
47
|
+
evaluator.new.to_h
|
48
|
+
end
|
49
|
+
|
50
|
+
def default_evaluator(task)
|
51
|
+
{
|
52
|
+
classification: {
|
53
|
+
metric: "accuracy_score",
|
54
|
+
threshold: 0.70,
|
55
|
+
direction: "maximize",
|
56
|
+
},
|
57
|
+
regression: {
|
58
|
+
metric: "root_mean_squared_error",
|
59
|
+
threshold: 10,
|
60
|
+
direction: "minimize",
|
61
|
+
},
|
62
|
+
}[task.to_sym]
|
63
|
+
end
|
64
|
+
|
65
|
+
def metrics_by_task
|
66
|
+
@registry.group_by { |_key, metric| metric[:type] }.transform_values do |group|
|
67
|
+
group.flat_map do |metric|
|
68
|
+
for_frontend(metric.last.dig(:evaluator))
|
38
69
|
end
|
70
|
+
end
|
71
|
+
end
|
72
|
+
|
73
|
+
def metrics(task = nil)
|
74
|
+
if task.nil?
|
75
|
+
@registry.keys
|
39
76
|
else
|
40
|
-
|
41
|
-
|
77
|
+
@registry.select do |_k, v|
|
78
|
+
v[:type].to_sym == task.to_sym
|
79
|
+
end.keys
|
42
80
|
end
|
43
|
-
|
44
|
-
accuracy_score: lambda { |y_pred, y_true|
|
45
|
-
y_pred = Numo::Int32.cast(y_pred)
|
46
|
-
y_true = Numo::Int32.cast(y_true)
|
47
|
-
y_pred.eq(y_true).count_true.to_f / y_pred.size
|
48
|
-
},
|
49
|
-
precision_score: lambda { |y_pred, y_true|
|
50
|
-
y_pred = Numo::Int32.cast(y_pred)
|
51
|
-
y_true = Numo::Int32.cast(y_true)
|
52
|
-
true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
|
53
|
-
predicted_positives = y_pred.eq(1).count_true
|
54
|
-
return 0 if predicted_positives == 0
|
55
|
-
|
56
|
-
true_positives.to_f / predicted_positives
|
57
|
-
},
|
58
|
-
recall_score: lambda { |y_pred, y_true|
|
59
|
-
y_pred = Numo::Int32.cast(y_pred)
|
60
|
-
y_true = Numo::Int32.cast(y_true)
|
61
|
-
true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
|
62
|
-
actual_positives = y_true.eq(1).count_true
|
63
|
-
true_positives.to_f / actual_positives
|
64
|
-
},
|
65
|
-
f1_score: lambda { |y_pred, y_true|
|
66
|
-
precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
|
67
|
-
recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
|
68
|
-
return 0 unless (precision + recall) > 0
|
69
|
-
|
70
|
-
2 * (precision * recall) / (precision + recall)
|
71
|
-
}
|
72
|
-
}
|
81
|
+
end
|
73
82
|
|
74
|
-
|
75
|
-
def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
|
83
|
+
def evaluate(model:, y_pred:, y_true:, x_true: nil, evaluator: nil)
|
76
84
|
y_pred = normalize_input(y_pred)
|
77
85
|
y_true = normalize_input(y_true)
|
78
86
|
check_size(y_pred, y_true)
|
@@ -80,45 +88,46 @@ module EasyML
|
|
80
88
|
metrics_results = {}
|
81
89
|
|
82
90
|
model.metrics.each do |metric|
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
91
|
+
evaluator_class = get(metric.to_sym)
|
92
|
+
next unless evaluator_class
|
93
|
+
|
94
|
+
evaluator_instance = evaluator_class.new
|
95
|
+
|
96
|
+
metrics_results[metric.to_sym] = evaluator_instance.evaluate(
|
97
|
+
y_pred: y_pred,
|
98
|
+
y_true: y_true,
|
99
|
+
x_true: x_true,
|
100
|
+
)
|
93
101
|
end
|
94
102
|
|
95
103
|
if evaluator.present?
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
else
|
103
|
-
raise "Don't know how to use CustomEvaluator. Must be a class that responds to evaluate or lambda"
|
104
|
-
end
|
104
|
+
evaluator = evaluator.symbolize_keys!
|
105
|
+
evaluator_class = get(evaluator[:metric])
|
106
|
+
raise "Unknown evaluator: #{evaluator}" unless evaluator_class
|
107
|
+
|
108
|
+
evaluator_instance = evaluator_class.new
|
109
|
+
response = evaluator_instance.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
|
105
110
|
|
106
111
|
if response.is_a?(Hash)
|
107
112
|
metrics_results.merge!(response)
|
108
113
|
else
|
109
|
-
metrics_results[:
|
114
|
+
metrics_results[evaluator[:metric].to_sym] = response
|
110
115
|
end
|
111
116
|
end
|
112
117
|
|
113
|
-
metrics_results
|
118
|
+
metrics_results.symbolize_keys
|
114
119
|
end
|
115
120
|
|
121
|
+
private
|
122
|
+
|
116
123
|
def check_size(y_pred, y_true)
|
117
124
|
raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
|
118
125
|
end
|
119
126
|
|
120
127
|
def normalize_input(input)
|
121
128
|
case input
|
129
|
+
when Array
|
130
|
+
Numo::DFloat.cast(input)
|
122
131
|
when Polars::DataFrame
|
123
132
|
if input.columns.count > 1
|
124
133
|
raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
|
@@ -135,3 +144,66 @@ module EasyML
|
|
135
144
|
end
|
136
145
|
end
|
137
146
|
end
|
147
|
+
|
148
|
+
# Register default evaluators
|
149
|
+
# Register the built-in evaluators, in order.
# Each row is [metric name, evaluator class, task type, aliases].
[
  [:mean_absolute_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanAbsoluteError, :regression, %w[mae]],
  [:mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanSquaredError, :regression, %w[mse]],
  [:root_mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::RootMeanSquaredError, :regression, %w[rmse]],
  [:r2_score, EasyML::Core::Evaluators::RegressionEvaluators::R2Score, :regression, %w[r2]],
  [:accuracy_score, EasyML::Core::Evaluators::ClassificationEvaluators::AccuracyScore, :classification, %w[accuracy]],
  [:precision_score, EasyML::Core::Evaluators::ClassificationEvaluators::PrecisionScore, :classification, %w[precision]],
  [:recall_score, EasyML::Core::Evaluators::ClassificationEvaluators::RecallScore, :classification, %w[recall]],
  [:f1_score, EasyML::Core::Evaluators::ClassificationEvaluators::F1Score, :classification, %w[f1]],
].each do |metric_name, evaluator_class, task_type, aliases|
  EasyML::Core::ModelEvaluator.register(metric_name, evaluator_class, task_type, aliases)
end

# AUC evaluators exist but are intentionally not registered yet:
# EasyML::Core::ModelEvaluator.register(
#   :auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::AUC,
#   :classification,
#   %w[auc]
# )
# EasyML::Core::ModelEvaluator.register(
#   :roc_auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::ROC_AUC,
#   :classification,
#   %w[roc_auc]
# )
|
@@ -3,39 +3,43 @@ module EasyML
|
|
3
3
|
class Tuner
|
4
4
|
module Adapters
|
5
5
|
class BaseAdapter
|
6
|
-
|
6
|
+
attr_accessor :config, :project_name, :tune_started_at, :model,
|
7
|
+
:x_true, :y_true, :metadata, :model
|
8
|
+
|
9
|
+
def initialize(options = {})
|
10
|
+
@model = options[:model]
|
11
|
+
@config = options[:config] || {}
|
12
|
+
@project_name = options[:project_name]
|
13
|
+
@tune_started_at = options[:tune_started_at]
|
14
|
+
@model = options[:model]
|
15
|
+
@x_true = options[:x_true]
|
16
|
+
@y_true = options[:y_true]
|
17
|
+
@metadata = options[:metadata] || {}
|
18
|
+
end
|
7
19
|
|
8
20
|
def defaults
|
9
21
|
{}
|
10
22
|
end
|
11
23
|
|
12
|
-
attribute :model
|
13
|
-
attribute :config, :hash
|
14
|
-
attribute :project_name, :string
|
15
|
-
attribute :tune_started_at
|
16
|
-
attribute :y_true
|
17
|
-
attribute :x_true
|
18
|
-
|
19
24
|
def run_trial(trial)
|
20
|
-
config = deep_merge_defaults(self.config.clone)
|
25
|
+
config = deep_merge_defaults(self.config.clone.deep_symbolize_keys)
|
21
26
|
suggest_parameters(trial, config)
|
22
|
-
model.fit
|
23
27
|
yield model
|
24
28
|
end
|
25
29
|
|
26
|
-
def configure_callbacks
|
27
|
-
raise "Subclasses fof Tuner::Adapter::BaseAdapter must define #configure_callbacks"
|
28
|
-
end
|
29
|
-
|
30
30
|
def suggest_parameters(trial, config)
|
31
|
-
|
32
|
-
|
33
|
-
|
31
|
+
config.keys.inject({}) do |hash, param_name|
|
32
|
+
hash.tap do
|
33
|
+
param_value = suggest_parameter(trial, param_name, config)
|
34
|
+
puts "Suggesting #{param_name}: #{param_value}"
|
35
|
+
model.hyperparameters.send("#{param_name}=", param_value)
|
36
|
+
hash[param_name] = param_value
|
37
|
+
end
|
34
38
|
end
|
35
39
|
end
|
36
40
|
|
37
41
|
def deep_merge_defaults(config)
|
38
|
-
defaults.deep_merge(config) do |_key, default_value, config_value|
|
42
|
+
defaults.deep_symbolize_keys.deep_merge(config.deep_symbolize_keys) do |_key, default_value, config_value|
|
39
43
|
if default_value.is_a?(Hash) && config_value.is_a?(Hash)
|
40
44
|
default_value.merge(config_value)
|
41
45
|
else
|
@@ -46,12 +50,18 @@ module EasyML
|
|
46
50
|
|
47
51
|
def suggest_parameter(trial, param_name, config)
|
48
52
|
param_config = config[param_name]
|
53
|
+
if !param_config.is_a?(Hash)
|
54
|
+
return param_config
|
55
|
+
end
|
56
|
+
|
49
57
|
min = param_config[:min]
|
50
58
|
max = param_config[:max]
|
51
59
|
log = param_config[:log]
|
52
60
|
|
53
61
|
if log
|
54
62
|
trial.suggest_loguniform(param_name.to_s, min, max)
|
63
|
+
elsif max.is_a?(Integer) && min.is_a?(Integer)
|
64
|
+
trial.suggest_int(param_name.to_s, min, max)
|
55
65
|
else
|
56
66
|
trial.suggest_uniform(param_name.to_s, min, max)
|
57
67
|
end
|