easy_ml 0.1.3 → 0.2.0.pre.rc1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (239) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -4
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
data/config/vite.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "all": {
3
+ "sourceCodeDir": "app/frontend",
4
+ "watchAdditionalPaths": [],
5
+ "publicOutputDir": "easy-ml"
6
+ },
7
+ "development": {
8
+ "autoBuild": true,
9
+ "port": 3037
10
+ },
11
+ "test": {
12
+ "autoBuild": true,
13
+ "publicOutputDir": "vite-test"
14
+ }
15
+ }
@@ -0,0 +1,64 @@
1
+ require "singleton"
2
+ require_relative "../../app/models/easy_ml/settings"
3
+
4
module EasyML
  # Application-wide configuration backed by a single EasyML::Settings record.
  #
  # For every attribute returned by EasyML::Settings.configuration_attributes this class
  # defines an instance reader and writer that delegate to the settings record, plus a
  # class-level reader. Attributes that also appear in LABELER get a `<attr>_label` reader
  # that returns the display label of the currently stored value.
  class Configuration
    include Singleton

    # Options offered for the `timezone` setting.
    TIMEZONES = [
      { value: "America/New_York", label: "Eastern Time" },
      { value: "America/Chicago", label: "Central Time" },
      { value: "America/Denver", label: "Mountain Time" },
      { value: "America/Los_Angeles", label: "Pacific Time" },
    ].freeze
    KEYS = EasyML::Settings.configuration_attributes
    # Attribute name (symbol) => list of { value:, label: } options used by `<attr>_label`.
    LABELER = {
      timezone: TIMEZONES,
    }.freeze

    KEYS.each do |key|
      define_method "#{key}=" do |value|
        db_settings.send("#{key}=", value)
      end

      define_method key do
        db_settings.send(key)
      end

      next unless LABELER.key?(key.to_sym)

      # LABELER is keyed by symbols, so index it with key.to_sym (KEYS may hold strings).
      # Returns nil instead of raising when the stored value matches no known option.
      define_method "#{key}_label" do
        LABELER[key.to_sym].find { |option| option[:value] == send(key) }&.dig(:label)
      end
    end

    class << self
      # Yields the singleton instance for modification, then persists the settings record.
      def configure
        yield instance
        instance.db_settings.save
      end

      KEYS.each do |key|
        define_method key do
          instance.send(key)
        end

        next unless LABELER.key?(key.to_sym)

        define_method "#{key}_label" do
          instance.send("#{key}_label")
        end
      end

      private

      def db_settings
        instance.db_settings
      end
    end

    # The settings record, loaded (or created) on first access and memoized.
    def db_settings
      @db_settings ||= EasyML::Settings.first_or_create
    end
  end
end
@@ -0,0 +1,53 @@
1
module EasyML
  module Core
    module Evaluators
      # Mixin shared by every metric evaluator.
      #
      # Including classes override #evaluate (and usually #direction). The mixin derives a
      # metric key from the class name and describes the evaluator as
      # { value:, label:, direction: }. Including also extends the class with ClassMethods,
      # which adds a class-level `registry` accessor.
      module BaseEvaluator
        def self.included(base)
          base.extend(ClassMethods)
        end

        # Optimisation direction of the metric; "minimize" unless overridden.
        def direction
          "minimize"
        end

        # Human-readable name derived from #key (underscores become spaces, then titleized).
        def label
          key.split("_").join(" ").titleize
        end

        # Wraps #to_h in an EasyML::Option.
        def to_option
          EasyML::Option.new(to_h)
        end

        def to_h
          {
            value: key,
            label: label,
            direction: direction,
          }
        end

        # Metric key derived from the unqualified class name, underscored.
        def key
          self.class.name.split("::").last.underscore
        end

        # Instance methods that evaluators must implement
        def evaluate(y_pred: nil, y_true: nil, x_true: nil)
          raise NotImplementedError, "#{self.class} must implement #evaluate"
        end

        # Picks this evaluator's value out of a metrics hash whose keys may be strings or
        # symbols. Uses #metric when the including class defines it, otherwise #key (the
        # original referenced an undefined #metric and always raised NameError).
        # The caller's hash is left untouched.
        def calculate_result(metrics)
          name = respond_to?(:metric) ? metric : key
          symbolized = metrics.transform_keys { |k| k.respond_to?(:to_sym) ? k.to_sym : k }
          symbolized[name.to_sym]
        end

        module ClassMethods
          def self.extended(base)
            class << base
              attr_accessor :registry
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,126 @@
1
module EasyML
  module Core
    module Evaluators
      # Binary-classification metrics. Inputs are cast with Numo; the label 1 is treated as
      # the positive class.
      module ClassificationEvaluators
        # Fraction of predictions equal to the true label.
        class AccuracyScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            y_pred.eq(y_true).count_true.to_f / y_pred.size
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FP); 0 when nothing is predicted positive.
        class PrecisionScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            predicted_positives = y_pred.eq(1).count_true
            return 0 if predicted_positives.zero?

            true_positives.to_f / predicted_positives
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FN); 0 when there are no positive labels.
        class RecallScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            actual_positives = y_true.eq(1).count_true
            # Mirror PrecisionScore: report 0 rather than NaN (0.0 / 0) when the
            # ground truth has no positives.
            return 0 if actual_positives.zero?

            true_positives.to_f / actual_positives
          end

          def direction
            "maximize"
          end
        end

        # Harmonic mean of precision and recall; 0 when both are 0.
        class F1Score
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            precision = PrecisionScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            recall = RecallScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            return 0 unless (precision + recall) > 0

            2 * (precision * recall) / (precision + recall)
          end

          def direction
            "maximize"
          end
        end

        # Area under the ROC curve. Labels equal to 1 are positive and anything else is negative.
        #
        # The decision threshold is swept from the highest score down to the lowest, and the
        # (FPR, TPR) points are integrated with the trapezoidal rule. Samples with equal scores
        # are added in one step, so ties contribute a diagonal segment instead of an
        # order-dependent staircase. The curve is undefined when only one class is present,
        # so Float::NAN is returned in that case.
        class AUC
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            scores = Numo::DFloat.cast(y_pred).flatten.to_a
            labels = Numo::Int32.cast(y_true).flatten.to_a

            positive_count = labels.count(1)
            negative_count = labels.size - positive_count
            return Float::NAN if positive_count.zero? || negative_count.zero?

            tp = 0
            fp = 0
            prev_tpr = 0.0
            prev_fpr = 0.0
            area = 0.0

            # Descending order: lowering the threshold admits the highest-scored samples first.
            scores.zip(labels)
                  .sort_by { |score, _label| -score }
                  .chunk_while { |left, right| left.first == right.first }
                  .each do |tied|
              tied.each do |_score, label|
                if label == 1
                  tp += 1
                else
                  fp += 1
                end
              end

              tpr = tp.to_f / positive_count
              fpr = fp.to_f / negative_count
              area += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0
              prev_tpr = tpr
              prev_fpr = fpr
            end

            area
          end

          def direction
            "maximize"
          end
        end

        # Alias of AUC under the roc_auc name.
        class ROC_AUC
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            AUC.new.evaluate(y_pred: y_pred, y_true: y_true)
          end

          def direction
            "maximize"
          end
        end
      end
    end
  end
end
@@ -0,0 +1,66 @@
1
module EasyML
  module Core
    module Evaluators
      # Regression metrics. Predictions and targets are cast to Numo::DFloat before comparison.
      module RegressionEvaluators
        # Mean of |prediction - truth|; lower is better.
        class MeanAbsoluteError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            predicted = Numo::DFloat.cast(y_pred)
            actual = Numo::DFloat.cast(y_true)
            (predicted - actual).abs.mean
          end
        end

        # Mean of squared residuals; lower is better.
        class MeanSquaredError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            residuals = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            (residuals ** 2).mean
          end
        end

        # Square root of the mean squared residual; lower is better.
        class RootMeanSquaredError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            residuals = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            mean_squared = (residuals ** 2).mean
            Math.sqrt(mean_squared)
          end
        end

        # Coefficient of determination, 1 - SS_res / SS_tot; higher is better.
        class R2Score
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            actual = Numo::DFloat.cast(y_true)
            predicted = Numo::DFloat.cast(y_pred)

            total_ss = ((actual - actual.mean) ** 2).sum
            residual_ss = ((actual - predicted) ** 2).sum

            return 1 - (residual_ss / total_ss) unless total_ss.zero?

            # Constant targets: only a perfect fit is defined (1.0); otherwise R² is undefined.
            residual_ss.zero? ? 1.0 : Float::NAN
          end

          def direction
            "maximize"
          end
        end
      end
    end
  end
end
@@ -1,78 +1,86 @@
1
+ require "numo/narray"
2
+ require_relative "evaluators/base_evaluator"
3
+ require_relative "evaluators/regression_evaluators"
4
+ require_relative "evaluators/classification_evaluators"
5
+
1
6
  module EasyML
2
7
  module Core
3
8
  class ModelEvaluator
4
- require "numo/narray"
5
-
6
- EVALUATORS = {
7
- mean_absolute_error: lambda { |y_pred, y_true|
8
- (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
9
- },
10
- mean_squared_error: lambda { |y_pred, y_true|
11
- ((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean
12
- },
13
- root_mean_squared_error: lambda { |y_pred, y_true|
14
- Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean)
15
- },
16
- r2_score: lambda { |y_pred, y_true|
17
- # Convert inputs to Numo::DFloat for numerical operations
18
- y_true = Numo::DFloat.cast(y_true)
19
- y_pred = Numo::DFloat.cast(y_pred)
20
-
21
- # Calculate the mean of the true values
22
- mean_y = y_true.mean
23
-
24
- # Calculate Total Sum of Squares (SS_tot)
25
- ss_tot = ((y_true - mean_y)**2).sum
26
-
27
- # Calculate Residual Sum of Squares (SS_res)
28
- ss_res = ((y_true - y_pred)**2).sum
29
-
30
- # Handle the edge case where SS_tot is zero
31
- if ss_tot.zero?
32
- if ss_res.zero?
33
- # Perfect prediction when both SS_tot and SS_res are zero
34
- 1.0
35
- else
36
- # Undefined R² when SS_tot is zero but SS_res is not
37
- Float::NAN
9
+ class << self
10
# Appends +callback+ to the registration listeners. Despite the setter name, this adds
# to the list instead of replacing it. Returns the full list.
def callbacks=(callback)
  (@callbacks ||= []) << callback
end
14
+
15
# Registered listeners, or an empty array when none have been added.
def callbacks
  return [] unless @callbacks

  @callbacks
end
18
+
19
# Registers +evaluator+ under +metric_name+ for the task +type+.
#
# Mixes Evaluators::BaseEvaluator into +evaluator+ if it is not already present, and
# notifies every listener in #callbacks with +metric_name+ before the entry is stored.
# +aliases+ may be nil, a single name, or a list of names; each is stored as a symbol.
# The default is now [] (it was {}, which only worked because an empty hash maps to []).
def register(metric_name, evaluator, type, aliases = [])
  @registry ||= {}
  evaluator.include(Evaluators::BaseEvaluator) unless evaluator.include?(Evaluators::BaseEvaluator)

  callbacks.each { |callback| callback.call(metric_name) }

  @registry[metric_name.to_sym] = {
    evaluator: evaluator,
    type: type,
    aliases: Array(aliases).map(&:to_sym),
  }
end
35
+
36
# Looks up the evaluator class registered under +name+ or one of its aliases.
#
# Returns nil for a nil name or an unknown name. The previous version called `.last` on
# the nil result of `detect`, so unknown names raised NoMethodError.
def get(name)
  return if name.nil?

  @registry ||= {}
  key = name.to_sym
  entry = @registry[key] ||
          @registry.each_value.find { |opts| Array(opts[:aliases]).include?(key) }
  entry&.dig(:evaluator)
end
45
+
46
# Describes +evaluator+ (a class) for the UI by instantiating it and calling #to_h.
def for_frontend(evaluator)
  instance = evaluator.new
  instance.to_h
end
49
+
50
# Default evaluator settings for +task+ (:classification or :regression, string or symbol).
# Returns a fresh hash on each call, or nil for any other task.
def default_evaluator(task)
  case task.to_sym
  when :classification
    { metric: "accuracy_score", threshold: 0.70, direction: "maximize" }
  when :regression
    { metric: "root_mean_squared_error", threshold: 10, direction: "minimize" }
  end
end
64
+
65
# Registered evaluators grouped by task type, each described via #for_frontend.
# Returns {} when nothing has been registered yet. The previous version raised
# NoMethodError on a nil @registry, unlike #get and #register, which tolerate it.
def metrics_by_task
  (@registry || {})
    .group_by { |_name, entry| entry[:type] }
    .transform_values { |entries| entries.map { |_name, entry| for_frontend(entry[:evaluator]) } }
end
72
+
73
+ def metrics(task = nil)
74
+ if task.nil?
75
+ @registry.keys
39
76
  else
40
- # Calculate
41
- 1 - (ss_res / ss_tot)
77
+ @registry.select do |_k, v|
78
+ v[:type].to_sym == task.to_sym
79
+ end.keys
42
80
  end
43
- },
44
- accuracy_score: lambda { |y_pred, y_true|
45
- y_pred = Numo::Int32.cast(y_pred)
46
- y_true = Numo::Int32.cast(y_true)
47
- y_pred.eq(y_true).count_true.to_f / y_pred.size
48
- },
49
- precision_score: lambda { |y_pred, y_true|
50
- y_pred = Numo::Int32.cast(y_pred)
51
- y_true = Numo::Int32.cast(y_true)
52
- true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
53
- predicted_positives = y_pred.eq(1).count_true
54
- return 0 if predicted_positives == 0
55
-
56
- true_positives.to_f / predicted_positives
57
- },
58
- recall_score: lambda { |y_pred, y_true|
59
- y_pred = Numo::Int32.cast(y_pred)
60
- y_true = Numo::Int32.cast(y_true)
61
- true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
62
- actual_positives = y_true.eq(1).count_true
63
- true_positives.to_f / actual_positives
64
- },
65
- f1_score: lambda { |y_pred, y_true|
66
- precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
67
- recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
68
- return 0 unless (precision + recall) > 0
69
-
70
- 2 * (precision * recall) / (precision + recall)
71
- }
72
- }
81
+ end
73
82
 
74
- class << self
75
- def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
83
+ def evaluate(model:, y_pred:, y_true:, x_true: nil, evaluator: nil)
76
84
  y_pred = normalize_input(y_pred)
77
85
  y_true = normalize_input(y_true)
78
86
  check_size(y_pred, y_true)
@@ -80,45 +88,46 @@ module EasyML
80
88
  metrics_results = {}
81
89
 
82
90
  model.metrics.each do |metric|
83
- if metric.is_a?(Module) || metric.is_a?(Class)
84
- unless metric.respond_to?(:evaluate)
85
- raise "Metric #{metric} must respond to #evaluate in order to be used as a custom evaluator"
86
- end
87
-
88
- metrics_results[metric.name] = metric.evaluate(y_pred, y_true)
89
- elsif EVALUATORS.key?(metric.to_sym)
90
- metrics_results[metric.to_sym] =
91
- EVALUATORS[metric.to_sym].call(y_pred, y_true)
92
- end
91
+ evaluator_class = get(metric.to_sym)
92
+ next unless evaluator_class
93
+
94
+ evaluator_instance = evaluator_class.new
95
+
96
+ metrics_results[metric.to_sym] = evaluator_instance.evaluate(
97
+ y_pred: y_pred,
98
+ y_true: y_true,
99
+ x_true: x_true,
100
+ )
93
101
  end
94
102
 
95
103
  if evaluator.present?
96
- if evaluator.is_a?(Class)
97
- response = evaluator.new.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
98
- elsif evaluator.respond_to?(:evaluate)
99
- response = evaluator.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
100
- elsif evaluator.respond_to?(:call)
101
- response = evaluator.call(y_pred: y_pred, y_true: y_true, x_true: x_true)
102
- else
103
- raise "Don't know how to use CustomEvaluator. Must be a class that responds to evaluate or lambda"
104
- end
104
+ evaluator = evaluator.symbolize_keys!
105
+ evaluator_class = get(evaluator[:metric])
106
+ raise "Unknown evaluator: #{evaluator}" unless evaluator_class
107
+
108
+ evaluator_instance = evaluator_class.new
109
+ response = evaluator_instance.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
105
110
 
106
111
  if response.is_a?(Hash)
107
112
  metrics_results.merge!(response)
108
113
  else
109
- metrics_results[:custom] = response
114
+ metrics_results[evaluator[:metric].to_sym] = response
110
115
  end
111
116
  end
112
117
 
113
- metrics_results
118
+ metrics_results.symbolize_keys
114
119
  end
115
120
 
121
+ private
122
+
116
123
  def check_size(y_pred, y_true)
117
124
  raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
118
125
  end
119
126
 
120
127
  def normalize_input(input)
121
128
  case input
129
+ when Array
130
+ Numo::DFloat.cast(input)
122
131
  when Polars::DataFrame
123
132
  if input.columns.count > 1
124
133
  raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
@@ -135,3 +144,66 @@ module EasyML
135
144
  end
136
145
  end
137
146
  end
147
+
148
# Register the built-in evaluators. Each row is
# [metric name, evaluator class, task type, aliases]; rows are registered in order.
[
  [:mean_absolute_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanAbsoluteError, :regression, %w[mae]],
  [:mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanSquaredError, :regression, %w[mse]],
  [:root_mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::RootMeanSquaredError, :regression, %w[rmse]],
  [:r2_score, EasyML::Core::Evaluators::RegressionEvaluators::R2Score, :regression, %w[r2]],
  [:accuracy_score, EasyML::Core::Evaluators::ClassificationEvaluators::AccuracyScore, :classification, %w[accuracy]],
  [:precision_score, EasyML::Core::Evaluators::ClassificationEvaluators::PrecisionScore, :classification, %w[precision]],
  [:recall_score, EasyML::Core::Evaluators::ClassificationEvaluators::RecallScore, :classification, %w[recall]],
  [:f1_score, EasyML::Core::Evaluators::ClassificationEvaluators::F1Score, :classification, %w[f1]],
].each do |metric_name, evaluator, type, aliases|
  EasyML::Core::ModelEvaluator.register(metric_name, evaluator, type, aliases)
end

# AUC-based evaluators are currently left unregistered:
# EasyML::Core::ModelEvaluator.register(
#   :auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::AUC,
#   :classification,
#   %w[auc]
# )
# EasyML::Core::ModelEvaluator.register(
#   :roc_auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::ROC_AUC,
#   :classification,
#   %w[roc_auc]
# )
@@ -3,39 +3,43 @@ module EasyML
3
3
  class Tuner
4
4
  module Adapters
5
5
  class BaseAdapter
6
- include GlueGun::DSL
6
attr_accessor :model, :config, :project_name, :tune_started_at,
              :x_true, :y_true, :metadata

# Builds the adapter from an options hash.
#
# Reads :model, :config, :project_name, :tune_started_at, :x_true, :y_true and :metadata.
# Missing :config and :metadata default to {}. (The duplicate :model accessor and the
# duplicate @model assignment in the original have been removed.)
def initialize(options = {})
  @model = options[:model]
  @config = options[:config] || {}
  @project_name = options[:project_name]
  @tune_started_at = options[:tune_started_at]
  @x_true = options[:x_true]
  @y_true = options[:y_true]
  @metadata = options[:metadata] || {}
end
7
19
 
8
20
  def defaults
9
21
  {}
10
22
  end
11
23
 
12
- attribute :model
13
- attribute :config, :hash
14
- attribute :project_name, :string
15
- attribute :tune_started_at
16
- attribute :y_true
17
- attribute :x_true
18
-
19
24
  def run_trial(trial)
20
- config = deep_merge_defaults(self.config.clone)
25
+ config = deep_merge_defaults(self.config.clone.deep_symbolize_keys)
21
26
  suggest_parameters(trial, config)
22
- model.fit
23
27
  yield model
24
28
  end
25
29
 
26
- def configure_callbacks
27
- raise "Subclasses fof Tuner::Adapter::BaseAdapter must define #configure_callbacks"
28
- end
29
-
30
30
  def suggest_parameters(trial, config)
31
- defaults.keys.each do |param_name|
32
- param_value = suggest_parameter(trial, param_name, config)
33
- model.hyperparameters.send("#{param_name}=", param_value)
31
+ config.keys.inject({}) do |hash, param_name|
32
+ hash.tap do
33
+ param_value = suggest_parameter(trial, param_name, config)
34
+ puts "Suggesting #{param_name}: #{param_value}"
35
+ model.hyperparameters.send("#{param_name}=", param_value)
36
+ hash[param_name] = param_value
37
+ end
34
38
  end
35
39
  end
36
40
 
37
41
  def deep_merge_defaults(config)
38
- defaults.deep_merge(config) do |_key, default_value, config_value|
42
+ defaults.deep_symbolize_keys.deep_merge(config.deep_symbolize_keys) do |_key, default_value, config_value|
39
43
  if default_value.is_a?(Hash) && config_value.is_a?(Hash)
40
44
  default_value.merge(config_value)
41
45
  else
@@ -46,12 +50,18 @@ module EasyML
46
50
 
47
51
  def suggest_parameter(trial, param_name, config)
48
52
  param_config = config[param_name]
53
+ if !param_config.is_a?(Hash)
54
+ return param_config
55
+ end
56
+
49
57
  min = param_config[:min]
50
58
  max = param_config[:max]
51
59
  log = param_config[:log]
52
60
 
53
61
  if log
54
62
  trial.suggest_loguniform(param_name.to_s, min, max)
63
+ elsif max.is_a?(Integer) && min.is_a?(Integer)
64
+ trial.suggest_int(param_name.to_s, min, max)
55
65
  else
56
66
  trial.suggest_uniform(param_name.to_s, min, max)
57
67
  end