easy_ml 0.1.4 → 0.2.0.pre.rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239)
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
data/config/vite.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "all": {
3
+ "sourceCodeDir": "app/frontend",
4
+ "watchAdditionalPaths": [],
5
+ "publicOutputDir": "easy-ml"
6
+ },
7
+ "development": {
8
+ "autoBuild": true,
9
+ "port": 3037
10
+ },
11
+ "test": {
12
+ "autoBuild": true,
13
+ "publicOutputDir": "vite-test"
14
+ }
15
+ }
@@ -0,0 +1,64 @@
1
require "singleton"
require_relative "../../app/models/easy_ml/settings"

module EasyML
  # Process-wide access point for EasyML settings.
  #
  # For every attribute in EasyML::Settings.configuration_attributes this
  # defines a reader and a writer on the singleton instance, plus a reader on
  # the class. All of them delegate to the single EasyML::Settings record
  # returned by #db_settings. Attributes with an entry in LABELER also get a
  # `<key>_label` reader that maps the stored value to its display label.
  class Configuration
    include Singleton

    TIMEZONES = [
      { value: "America/New_York", label: "Eastern Time" },
      { value: "America/Chicago", label: "Central Time" },
      { value: "America/Denver", label: "Mountain Time" },
      { value: "America/Los_Angeles", label: "Pacific Time" },
    ]
    KEYS = EasyML::Settings.configuration_attributes
    LABELER = {
      timezone: TIMEZONES,
    }

    KEYS.each do |key|
      define_method "#{key}=" do |value|
        db_settings.send("#{key}=", value)
      end

      define_method key do
        db_settings.send(key)
      end

      # Always look up with a symbol: KEYS may hold strings, and LABELER is
      # symbol-keyed. Indexing with the raw key would return nil for strings.
      options = LABELER[key.to_sym]
      next unless options

      # Returns nil instead of raising when the stored value has no label.
      define_method "#{key}_label" do
        options.find { |option| option[:value] == send(key) }&.dig(:label)
      end
    end

    class << self
      # Yields the singleton so callers can assign settings, then saves the
      # underlying settings record.
      def configure
        yield instance
        instance.db_settings.save
      end

      KEYS.each do |key|
        define_method key do
          instance.send(key)
        end

        next unless LABELER.key?(key.to_sym)

        define_method "#{key}_label" do
          instance.send("#{key}_label")
        end
      end

      private

      def db_settings
        instance.db_settings
      end
    end

    # Memoized settings record, loaded (or created) via
    # EasyML::Settings.first_or_create on first access.
    def db_settings
      @db_settings ||= EasyML::Settings.first_or_create
    end
  end
end
@@ -0,0 +1,53 @@
1
module EasyML
  module Core
    module Evaluators
      # Mixin included by every metric evaluator.
      #
      # It supplies defaults that concrete evaluators may override:
      # - a metric key derived from the class name,
      # - a titleized label,
      # - a "minimize" direction.
      #
      # It also declares the #evaluate contract. Including the module extends
      # the host class with ClassMethods.
      module BaseEvaluator
        def self.included(base)
          base.extend(ClassMethods)
        end

        # Optimization direction for this metric. Evaluators where larger is
        # better override this to return "maximize".
        def direction
          "minimize"
        end

        # Human-readable name built from #key, e.g. "accuracy_score" -> "Accuracy Score".
        def label
          key.split("_").join(" ").titleize
        end

        def to_option
          EasyML::Option.new(to_h)
        end

        # @return [Hash] { value:, label:, direction: } describing this metric
        def to_h
          {
            value: key,
            label: label,
            direction: direction,
          }
        end

        # Metric key derived from the unqualified class name, e.g. F1Score -> "f1_score".
        def key
          self.class.name.split("::").last.underscore
        end

        # Instance method that evaluators must implement.
        #
        # @raise [NotImplementedError] unless overridden
        def evaluate(y_pred: nil, y_true: nil, x_true: nil)
          raise NotImplementedError, "#{self.class} must implement #evaluate"
        end

        # Extracts this evaluator's value from a metrics hash.
        #
        # @param metrics [Hash] metric name (String or Symbol) => value;
        #   the caller's hash is left unmodified
        # @return [Object, nil] the value for this metric, or nil when absent
        def calculate_result(metrics)
          # A host may define #metric to choose the lookup name. Otherwise fall
          # back to #key; previously this referenced an undefined #metric.
          metric_name = respond_to?(:metric, true) ? metric : key
          symbolized = metrics.transform_keys { |k| k.respond_to?(:to_sym) ? k.to_sym : k }
          symbolized[metric_name.to_sym]
        end

        module ClassMethods
          # Gives each evaluator class its own `registry` accessor.
          def self.extended(base)
            class << base
              attr_accessor :registry
            end
          end
        end
      end
    end
  end
end
@@ -0,0 +1,126 @@
1
module EasyML
  module Core
    module Evaluators
      # Binary-classification metrics. Labels are cast to Numo::Int32 and the
      # positive class is the label 1.
      module ClassificationEvaluators
        # Fraction of predictions that equal the true label.
        class AccuracyScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            y_pred.eq(y_true).count_true.to_f / y_pred.size
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FP). Returns 0 when nothing is predicted positive.
        class PrecisionScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            predicted_positives = y_pred.eq(1).count_true
            return 0 if predicted_positives.zero?

            true_positives.to_f / predicted_positives
          end

          def direction
            "maximize"
          end
        end

        # TP / (TP + FN). Returns 0 when y_true has no positives.
        class RecallScore
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            y_pred = Numo::Int32.cast(y_pred)
            y_true = Numo::Int32.cast(y_true)
            true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
            actual_positives = y_true.eq(1).count_true
            # Mirror PrecisionScore: avoid a 0/0 (NaN) result when no positives exist.
            return 0 if actual_positives.zero?

            true_positives.to_f / actual_positives
          end

          def direction
            "maximize"
          end
        end

        # Harmonic mean of precision and recall. Returns 0 when both are 0.
        class F1Score
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            precision = PrecisionScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            recall = RecallScore.new.evaluate(y_pred: y_pred, y_true: y_true)
            return 0 unless (precision + recall) > 0

            2 * (precision * recall) / (precision + recall)
          end

          def direction
            "maximize"
          end
        end

        # Area under the ROC curve, computed with the trapezoidal rule.
        class AUC
          include BaseEvaluator

          # Scores are visited from highest to lowest, so the curve runs from
          # (0, 0) to (1, 1). Ascending order would invert the result, so that
          # a perfect classifier scored 0. Tied scores are consumed as one
          # group, giving a diagonal segment rather than an order-dependent
          # staircase.
          #
          # @param y_pred [#to_a] predicted scores (higher means more likely positive)
          # @param y_true [#to_a] true labels; 1 is positive, anything else negative
          # @return [Float] AUC in [0, 1], or Float::NAN when y_true contains a
          #   single class (the ROC curve is undefined)
          def evaluate(y_pred:, y_true:, x_true: nil)
            scores = Numo::DFloat.cast(y_pred).to_a
            labels = Numo::Int32.cast(y_true).to_a

            positive_count = labels.count(1)
            negative_count = labels.size - positive_count
            return Float::NAN if positive_count.zero? || negative_count.zero?

            tied_groups = scores.zip(labels).group_by(&:first).sort_by { |score, _pairs| -score }

            tp = 0
            fp = 0
            prev_tpr = 0.0
            prev_fpr = 0.0
            auc = 0.0

            tied_groups.each do |_score, pairs|
              group_tp = pairs.count { |_s, label| label == 1 }
              tp += group_tp
              fp += pairs.size - group_tp

              tpr = tp.to_f / positive_count
              fpr = fp.to_f / negative_count
              auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0
              prev_tpr = tpr
              prev_fpr = fpr
            end

            auc
          end

          def direction
            "maximize"
          end
        end

        # Alias of AUC under the conventional "ROC AUC" name.
        class ROC_AUC
          include BaseEvaluator

          def evaluate(y_pred:, y_true:, x_true: nil)
            AUC.new.evaluate(y_pred: y_pred, y_true: y_true)
          end

          def direction
            "maximize"
          end
        end
      end
    end
  end
end
@@ -0,0 +1,66 @@
1
module EasyML
  module Core
    module Evaluators
      # Regression metrics. Both inputs are cast to Numo::DFloat before use.
      module RegressionEvaluators
        # Mean of |prediction - target|.
        class MeanAbsoluteError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            errors = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            errors.abs.mean
          end
        end

        # Mean of (prediction - target)^2.
        class MeanSquaredError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            errors = Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)
            (errors ** 2).mean
          end
        end

        # Square root of the mean squared error.
        class RootMeanSquaredError
          include BaseEvaluator

          def direction
            "minimize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            squared_errors = (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)) ** 2
            Math.sqrt(squared_errors.mean)
          end
        end

        # Coefficient of determination: 1 - SS_res / SS_tot.
        class R2Score
          include BaseEvaluator

          def direction
            "maximize"
          end

          def evaluate(y_pred:, y_true:, x_true: nil)
            actual = Numo::DFloat.cast(y_true)
            predicted = Numo::DFloat.cast(y_pred)

            ss_tot = ((actual - actual.mean) ** 2).sum
            ss_res = ((actual - predicted) ** 2).sum

            # Targets vary: the usual definition applies.
            return 1 - (ss_res / ss_tot) unless ss_tot.zero?

            # Constant targets: a perfect fit scores 1.0; otherwise R² is undefined.
            ss_res.zero? ? 1.0 : Float::NAN
          end
        end
      end
    end
  end
end
@@ -1,78 +1,86 @@
1
+ require "numo/narray"
2
+ require_relative "evaluators/base_evaluator"
3
+ require_relative "evaluators/regression_evaluators"
4
+ require_relative "evaluators/classification_evaluators"
5
+
1
6
  module EasyML
2
7
  module Core
3
8
  class ModelEvaluator
4
- require "numo/narray"
5
-
6
- EVALUATORS = {
7
- mean_absolute_error: lambda { |y_pred, y_true|
8
- (Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true)).abs.mean
9
- },
10
- mean_squared_error: lambda { |y_pred, y_true|
11
- ((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean
12
- },
13
- root_mean_squared_error: lambda { |y_pred, y_true|
14
- Math.sqrt(((Numo::DFloat.cast(y_pred) - Numo::DFloat.cast(y_true))**2).mean)
15
- },
16
- r2_score: lambda { |y_pred, y_true|
17
- # Convert inputs to Numo::DFloat for numerical operations
18
- y_true = Numo::DFloat.cast(y_true)
19
- y_pred = Numo::DFloat.cast(y_pred)
20
-
21
- # Calculate the mean of the true values
22
- mean_y = y_true.mean
23
-
24
- # Calculate Total Sum of Squares (SS_tot)
25
- ss_tot = ((y_true - mean_y)**2).sum
26
-
27
- # Calculate Residual Sum of Squares (SS_res)
28
- ss_res = ((y_true - y_pred)**2).sum
29
-
30
- # Handle the edge case where SS_tot is zero
31
- if ss_tot.zero?
32
- if ss_res.zero?
33
- # Perfect prediction when both SS_tot and SS_res are zero
34
- 1.0
35
- else
36
- # Undefined R² when SS_tot is zero but SS_res is not
37
- Float::NAN
9
+ class << self
10
+ def callbacks=(callback)
11
+ @callbacks ||= []
12
+ @callbacks.push(callback)
13
+ end
14
+
15
+ def callbacks
16
+ @callbacks || []
17
+ end
18
+
19
+ def register(metric_name, evaluator, type, aliases = {})
20
+ @registry ||= {}
21
+ unless evaluator.included_modules.include?(Evaluators::BaseEvaluator)
22
+ evaluator.include(Evaluators::BaseEvaluator)
23
+ end
24
+
25
+ callbacks.each do |callback|
26
+ callback.call(metric_name)
27
+ end
28
+
29
+ @registry[metric_name.to_sym] = {
30
+ evaluator: evaluator,
31
+ type: type,
32
+ aliases: (aliases || []).map(&:to_sym),
33
+ }
34
+ end
35
+
36
+ def get(name)
37
+ return if name.nil?
38
+
39
+ @registry ||= {}
40
+ option = (@registry[name.to_sym] || @registry.detect do |_k, opts|
41
+ opts[:aliases].include?(name.to_sym)
42
+ end.last) || {}
43
+ option.dig(:evaluator)
44
+ end
45
+
46
+ def for_frontend(evaluator)
47
+ evaluator.new.to_h
48
+ end
49
+
50
+ def default_evaluator(task)
51
+ {
52
+ classification: {
53
+ metric: "accuracy_score",
54
+ threshold: 0.70,
55
+ direction: "maximize",
56
+ },
57
+ regression: {
58
+ metric: "root_mean_squared_error",
59
+ threshold: 10,
60
+ direction: "minimize",
61
+ },
62
+ }[task.to_sym]
63
+ end
64
+
65
+ def metrics_by_task
66
+ @registry.group_by { |_key, metric| metric[:type] }.transform_values do |group|
67
+ group.flat_map do |metric|
68
+ for_frontend(metric.last.dig(:evaluator))
38
69
  end
70
+ end
71
+ end
72
+
73
+ def metrics(task = nil)
74
+ if task.nil?
75
+ @registry.keys
39
76
  else
40
- # Calculate
41
- 1 - (ss_res / ss_tot)
77
+ @registry.select do |_k, v|
78
+ v[:type].to_sym == task.to_sym
79
+ end.keys
42
80
  end
43
- },
44
- accuracy_score: lambda { |y_pred, y_true|
45
- y_pred = Numo::Int32.cast(y_pred)
46
- y_true = Numo::Int32.cast(y_true)
47
- y_pred.eq(y_true).count_true.to_f / y_pred.size
48
- },
49
- precision_score: lambda { |y_pred, y_true|
50
- y_pred = Numo::Int32.cast(y_pred)
51
- y_true = Numo::Int32.cast(y_true)
52
- true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
53
- predicted_positives = y_pred.eq(1).count_true
54
- return 0 if predicted_positives == 0
55
-
56
- true_positives.to_f / predicted_positives
57
- },
58
- recall_score: lambda { |y_pred, y_true|
59
- y_pred = Numo::Int32.cast(y_pred)
60
- y_true = Numo::Int32.cast(y_true)
61
- true_positives = (y_pred.eq(1) & y_true.eq(1)).count_true
62
- actual_positives = y_true.eq(1).count_true
63
- true_positives.to_f / actual_positives
64
- },
65
- f1_score: lambda { |y_pred, y_true|
66
- precision = EVALUATORS[:precision_score].call(y_pred, y_true) || 0
67
- recall = EVALUATORS[:recall_score].call(y_pred, y_true) || 0
68
- return 0 unless (precision + recall) > 0
69
-
70
- 2 * (precision * recall) / (precision + recall)
71
- }
72
- }
81
+ end
73
82
 
74
- class << self
75
- def evaluate(model: nil, y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
83
+ def evaluate(model:, y_pred:, y_true:, x_true: nil, evaluator: nil)
76
84
  y_pred = normalize_input(y_pred)
77
85
  y_true = normalize_input(y_true)
78
86
  check_size(y_pred, y_true)
@@ -80,45 +88,46 @@ module EasyML
80
88
  metrics_results = {}
81
89
 
82
90
  model.metrics.each do |metric|
83
- if metric.is_a?(Module) || metric.is_a?(Class)
84
- unless metric.respond_to?(:evaluate)
85
- raise "Metric #{metric} must respond to #evaluate in order to be used as a custom evaluator"
86
- end
87
-
88
- metrics_results[metric.name] = metric.evaluate(y_pred, y_true)
89
- elsif EVALUATORS.key?(metric.to_sym)
90
- metrics_results[metric.to_sym] =
91
- EVALUATORS[metric.to_sym].call(y_pred, y_true)
92
- end
91
+ evaluator_class = get(metric.to_sym)
92
+ next unless evaluator_class
93
+
94
+ evaluator_instance = evaluator_class.new
95
+
96
+ metrics_results[metric.to_sym] = evaluator_instance.evaluate(
97
+ y_pred: y_pred,
98
+ y_true: y_true,
99
+ x_true: x_true,
100
+ )
93
101
  end
94
102
 
95
103
  if evaluator.present?
96
- if evaluator.is_a?(Class)
97
- response = evaluator.new.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
98
- elsif evaluator.respond_to?(:evaluate)
99
- response = evaluator.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
100
- elsif evaluator.respond_to?(:call)
101
- response = evaluator.call(y_pred: y_pred, y_true: y_true, x_true: x_true)
102
- else
103
- raise "Don't know how to use CustomEvaluator. Must be a class that responds to evaluate or lambda"
104
- end
104
+ evaluator = evaluator.symbolize_keys!
105
+ evaluator_class = get(evaluator[:metric])
106
+ raise "Unknown evaluator: #{evaluator}" unless evaluator_class
107
+
108
+ evaluator_instance = evaluator_class.new
109
+ response = evaluator_instance.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
105
110
 
106
111
  if response.is_a?(Hash)
107
112
  metrics_results.merge!(response)
108
113
  else
109
- metrics_results[:custom] = response
114
+ metrics_results[evaluator[:metric].to_sym] = response
110
115
  end
111
116
  end
112
117
 
113
- metrics_results
118
+ metrics_results.symbolize_keys
114
119
  end
115
120
 
121
+ private
122
+
116
123
  def check_size(y_pred, y_true)
117
124
  raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
118
125
  end
119
126
 
120
127
  def normalize_input(input)
121
128
  case input
129
+ when Array
130
+ Numo::DFloat.cast(input)
122
131
  when Polars::DataFrame
123
132
  if input.columns.count > 1
124
133
  raise ArgumentError, "Don't know how to evaluate input with multiple columns: #{input}"
@@ -135,3 +144,66 @@ module EasyML
135
144
  end
136
145
  end
137
146
  end
147
+
148
# Register the built-in evaluators, in this order. Each row is
# [metric name, evaluator class, task type, aliases].
[
  [:mean_absolute_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanAbsoluteError, :regression, %w[mae]],
  [:mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::MeanSquaredError, :regression, %w[mse]],
  [:root_mean_squared_error, EasyML::Core::Evaluators::RegressionEvaluators::RootMeanSquaredError, :regression, %w[rmse]],
  [:r2_score, EasyML::Core::Evaluators::RegressionEvaluators::R2Score, :regression, %w[r2]],
  [:accuracy_score, EasyML::Core::Evaluators::ClassificationEvaluators::AccuracyScore, :classification, %w[accuracy]],
  [:precision_score, EasyML::Core::Evaluators::ClassificationEvaluators::PrecisionScore, :classification, %w[precision]],
  [:recall_score, EasyML::Core::Evaluators::ClassificationEvaluators::RecallScore, :classification, %w[recall]],
  [:f1_score, EasyML::Core::Evaluators::ClassificationEvaluators::F1Score, :classification, %w[f1]],
].each do |metric_name, evaluator_class, task_type, aliases|
  EasyML::Core::ModelEvaluator.register(metric_name, evaluator_class, task_type, aliases)
end

# AUC-based evaluators are currently left unregistered:
# EasyML::Core::ModelEvaluator.register(
#   :auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::AUC,
#   :classification,
#   %w[auc]
# )
# EasyML::Core::ModelEvaluator.register(
#   :roc_auc,
#   EasyML::Core::Evaluators::ClassificationEvaluators::ROC_AUC,
#   :classification,
#   %w[roc_auc]
# )
@@ -3,39 +3,43 @@ module EasyML
3
3
  class Tuner
4
4
  module Adapters
5
5
  class BaseAdapter
6
- include GlueGun::DSL
6
+ attr_accessor :config, :project_name, :tune_started_at, :model,
7
+ :x_true, :y_true, :metadata, :model
8
+
9
+ def initialize(options = {})
10
+ @model = options[:model]
11
+ @config = options[:config] || {}
12
+ @project_name = options[:project_name]
13
+ @tune_started_at = options[:tune_started_at]
14
+ @model = options[:model]
15
+ @x_true = options[:x_true]
16
+ @y_true = options[:y_true]
17
+ @metadata = options[:metadata] || {}
18
+ end
7
19
 
8
20
  def defaults
9
21
  {}
10
22
  end
11
23
 
12
- attribute :model
13
- attribute :config, :hash
14
- attribute :project_name, :string
15
- attribute :tune_started_at
16
- attribute :y_true
17
- attribute :x_true
18
-
19
24
  def run_trial(trial)
20
- config = deep_merge_defaults(self.config.clone)
25
+ config = deep_merge_defaults(self.config.clone.deep_symbolize_keys)
21
26
  suggest_parameters(trial, config)
22
- model.fit
23
27
  yield model
24
28
  end
25
29
 
26
- def configure_callbacks
27
- raise "Subclasses fof Tuner::Adapter::BaseAdapter must define #configure_callbacks"
28
- end
29
-
30
30
  def suggest_parameters(trial, config)
31
- defaults.keys.each do |param_name|
32
- param_value = suggest_parameter(trial, param_name, config)
33
- model.hyperparameters.send("#{param_name}=", param_value)
31
+ config.keys.inject({}) do |hash, param_name|
32
+ hash.tap do
33
+ param_value = suggest_parameter(trial, param_name, config)
34
+ puts "Suggesting #{param_name}: #{param_value}"
35
+ model.hyperparameters.send("#{param_name}=", param_value)
36
+ hash[param_name] = param_value
37
+ end
34
38
  end
35
39
  end
36
40
 
37
41
  def deep_merge_defaults(config)
38
- defaults.deep_merge(config) do |_key, default_value, config_value|
42
+ defaults.deep_symbolize_keys.deep_merge(config.deep_symbolize_keys) do |_key, default_value, config_value|
39
43
  if default_value.is_a?(Hash) && config_value.is_a?(Hash)
40
44
  default_value.merge(config_value)
41
45
  else
@@ -46,12 +50,18 @@ module EasyML
46
50
 
47
51
  def suggest_parameter(trial, param_name, config)
48
52
  param_config = config[param_name]
53
+ if !param_config.is_a?(Hash)
54
+ return param_config
55
+ end
56
+
49
57
  min = param_config[:min]
50
58
  max = param_config[:max]
51
59
  log = param_config[:log]
52
60
 
53
61
  if log
54
62
  trial.suggest_loguniform(param_name.to_s, min, max)
63
+ elsif max.is_a?(Integer) && min.is_a?(Integer)
64
+ trial.suggest_int(param_name.to_s, min, max)
55
65
  else
56
66
  trial.suggest_uniform(param_name.to_s, min, max)
57
67
  end