easy_ml 0.1.4 → 0.2.0.pre.rc1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (239) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -5,44 +5,23 @@ module EasyML
5
5
  class Tuner
6
6
  module Adapters
7
7
  class XGBoostAdapter < BaseAdapter
8
- include GlueGun::DSL
9
-
10
8
  def defaults
11
9
  {
12
10
  learning_rate: {
13
11
  min: 0.001,
14
12
  max: 0.1,
15
- log: true
13
+ log: true,
16
14
  },
17
15
  n_estimators: {
18
16
  min: 100,
19
- max: 1_000
17
+ max: 1_000,
20
18
  },
21
19
  max_depth: {
22
20
  min: 2,
23
- max: 20
24
- }
21
+ max: 20,
22
+ },
25
23
  }
26
24
  end
27
-
28
- def configure_callbacks
29
- model.customize_callbacks do |callbacks|
30
- return unless callbacks.present?
31
-
32
- wandb_callback = callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
33
- return unless wandb_callback.present?
34
-
35
- wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
36
- wandb_callback.custom_loggers = [
37
- lambda do |booster, _epoch, _hist|
38
- dtrain = model.send(:preprocess, x_true, y_true)
39
- y_pred = booster.predict(dtrain)
40
- metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
41
- Wandb.log(metrics)
42
- end
43
- ]
44
- end
45
- end
46
25
  end
47
26
  end
48
27
  end
@@ -4,34 +4,36 @@ require_relative "tuner/adapters"
4
4
  module EasyML
5
5
  module Core
6
6
  class Tuner
7
- include GlueGun::DSL
8
-
9
- attribute :model
10
- attribute :dataset
11
- attribute :project_name, :string
12
- attribute :task, :string
13
- attribute :config, :hash, default: {}
14
- attribute :metrics, :array
15
- attribute :objective, :string
16
- attribute :n_trials, default: 100
17
- attribute :callbacks, :array
18
- attr_accessor :study, :results
19
-
20
- dependency :adapter, lazy: false do |dep|
21
- dep.option :xgboost do |opt|
22
- opt.set_class Adapters::XGBoostAdapter
23
- opt.bind_attribute :model
24
- opt.bind_attribute :config
25
- opt.bind_attribute :project_name
26
- opt.bind_attribute :tune_started_at
27
- opt.bind_attribute :y_true
28
- end
7
+ attr_accessor :model, :dataset, :project_name, :task, :config,
8
+ :metrics, :objective, :n_trials, :direction, :evaluator,
9
+ :study, :results, :adapter, :tune_started_at, :x_true, :y_true,
10
+ :project_name, :job, :current_run
11
+
12
+ def initialize(options = {})
13
+ @model = options[:model]
14
+ @dataset = options[:dataset]
15
+ @project_name = options[:project_name]
16
+ @task = options[:task]
17
+ @config = options[:config] || {}
18
+ @metrics = options[:metrics]
19
+ @objective = options[:objective]
20
+ @n_trials = options[:n_trials] || 100
21
+ @direction = EasyML::Core::ModelEvaluator.get(objective).new.direction
22
+ @evaluator = options[:evaluator]
23
+ @tune_started_at = EasyML::Support::UTC.now
24
+ @project_name = "#{@model.name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
25
+ end
29
26
 
30
- dep.when do |_dep|
31
- case model
32
- when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
33
- { option: :xgboost }
34
- end
27
+ def initialize_adapter
28
+ case model&.model_type
29
+ when "xgboost"
30
+ Adapters::XGBoostAdapter.new(
31
+ model: model,
32
+ config: config,
33
+ project_name: project_name,
34
+ tune_started_at: nil, # This will be set during tune
35
+ y_true: nil, # This will be set during tune
36
+ )
35
37
  end
36
38
  end
37
39
 
@@ -41,53 +43,112 @@ module EasyML
41
43
  raise "Trial failed: Stopping optimization."
42
44
  end
43
45
 
44
- def tune
45
- set_defaults!
46
+ def wandb_enabled?
47
+ EasyML::Configuration.wandb_api_key.present?
48
+ end
46
49
 
47
- @study = Optuna::Study.new
50
+ def tune(&progress_block)
51
+ set_defaults!
52
+ @adapter = initialize_adapter
53
+
54
+ tuner_params = {
55
+ model: model,
56
+ config: {
57
+ n_trials: n_trials,
58
+ objective: objective,
59
+ hyperparameter_ranges: config,
60
+ },
61
+ direction: direction,
62
+ status: :running,
63
+ started_at: Time.current,
64
+ wandb_url: wandb_enabled? ? "https://wandb.ai/fundera/#{@project_name}" : nil,
65
+ }.compact
66
+
67
+ tuner_job = EasyML::TunerJob.create!(tuner_params)
68
+ @job = tuner_job
69
+ @study = Optuna::Study.new(direction: direction)
48
70
  @results = []
71
+ model.evaluator = evaluator if evaluator.present?
49
72
  model.task = task
73
+
74
+ model.dataset.refresh
50
75
  x_true, y_true = model.dataset.test(split_ys: true)
51
- tune_started_at = EST.now
52
- adapter = pick_adapter.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
53
- x_true: x_true)
54
- adapter.configure_callbacks
55
-
56
- @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
57
- run_metrics = tune_once(trial, x_true, y_true, adapter)
58
-
59
- result = if model.evaluator.present?
60
- if model.evaluator_metric.present?
61
- run_metrics[model.evaluator_metric]
62
- else
63
- run_metrics[:custom]
64
- end
65
- else
66
- run_metrics[objective.to_sym]
67
- end
68
- @results.push(result)
69
- result
70
- rescue StandardError => e
71
- puts "Optuna failed with: #{e.message}"
76
+ self.x_true = x_true
77
+ self.y_true = y_true
78
+ adapter.tune_started_at = tune_started_at
79
+ adapter.y_true = y_true
80
+ adapter.x_true = x_true
81
+
82
+ model.prepare_data unless model.batch_mode
83
+ model.prepare_callbacks(self)
84
+
85
+ n_trials.times do |run_number|
86
+ trial = @study.ask
87
+ puts "Running trial #{trial.number}"
88
+ @tuner_run = tuner_job.tuner_runs.new(
89
+ trial_number: trial.number,
90
+ status: :running,
91
+ )
92
+
93
+ self.current_run = @tuner_run
94
+
95
+ begin
96
+ run_metrics = tune_once(trial, x_true, y_true, adapter, &progress_block)
97
+ result = calculate_result(run_metrics)
98
+ @results.push(result)
99
+
100
+ params = {
101
+ hyperparameters: model.hyperparameters.to_h,
102
+ value: result,
103
+ status: :success,
104
+ }.compact
105
+
106
+ @tuner_run.update!(params)
107
+ @study.tell(trial, result)
108
+ rescue StandardError => e
109
+ @tuner_run.update!(status: :failed, hyperparameters: {})
110
+ puts "Optuna failed with: #{e.message}"
111
+ raise e
112
+ end
72
113
  end
73
114
 
74
- raise "Optuna study failed" unless @study.respond_to?(:best_trial)
75
-
76
- @study.best_trial.params
115
+ model.after_tuning
116
+ return nil if tuner_job.tuner_runs.all?(&:failed?)
117
+
118
+ best_run = tuner_job.best_run
119
+ tuner_job.update!(
120
+ metadata: adapter.metadata,
121
+ best_tuner_run_id: best_run&.id,
122
+ status: :success,
123
+ completed_at: Time.current,
124
+ )
125
+
126
+ best_run&.hyperparameters
127
+ rescue StandardError => e
128
+ tuner_job&.update!(status: :failed, completed_at: Time.current)
129
+ raise e
77
130
  end
78
131
 
79
- def pick_adapter
80
- case model
81
- when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
82
- Adapters::XGBoostAdapter
132
+ private
133
+
134
+ def calculate_result(run_metrics)
135
+ run_metrics.symbolize_keys!
136
+
137
+ if model.evaluator.present?
138
+ run_metrics[model.evaluator[:metric].to_sym]
139
+ else
140
+ run_metrics[objective.to_sym]
83
141
  end
84
142
  end
85
143
 
86
- def tune_once(trial, x_true, y_true, adapter)
144
+ def tune_once(trial, x_true, y_true, adapter, &progress_block)
87
145
  adapter.run_trial(trial) do |model|
88
- y_pred = model.predict(y_true)
146
+ model.fit(tuning: true, &progress_block)
147
+ y_pred = model.predict(x_true)
89
148
  model.metrics = metrics
90
- model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
149
+ metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
150
+ puts metrics
151
+ metrics
91
152
  end
92
153
  end
93
154
 
@@ -98,7 +159,7 @@ module EasyML
98
159
  end
99
160
  raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless objective.present?
100
161
 
101
- self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
162
+ self.metrics = EasyML::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
102
163
  end
103
164
  end
104
165
  end
data/lib/easy_ml/core.rb CHANGED
@@ -1,8 +1,5 @@
1
1
  module EasyML
2
2
  module Core
3
- require_relative "core/uploaders"
4
- require_relative "core/model"
5
- require_relative "core/models"
6
3
  require_relative "core/model_evaluator"
7
4
  require_relative "core/tuner"
8
5
  end
@@ -0,0 +1,24 @@
module EasyML
  module CoreExt
    # Adds #deep_compact to the core ::Hash class (mixed in at the bottom of
    # this file).
    module Hash
      # Returns a new hash with nil values removed recursively.
      #
      # - Nested hashes are compacted recursively.
      # - Arrays have nil elements removed, and hash elements are compacted
      #   recursively (nested arrays are left as-is).
      # - After compaction, values that are `blank?` (e.g. empty hashes, empty
      #   arrays, empty strings) are dropped, EXCEPT `false`. `false` is a
      #   meaningful value (e.g. a `log: false` flag) and is kept.
      #
      # Relies on ActiveSupport's Object#blank?.
      #
      # @return [::Hash] a new hash; the receiver is not modified
      def deep_compact
        each_with_object({}) do |(key, value), result|
          next if value.nil?

          # `::Hash` is required: inside this module a bare `Hash` resolves
          # lexically to EasyML::CoreExt::Hash, not the core class.
          compacted = if value.is_a?(::Hash)
                        value.deep_compact
                      elsif value.is_a?(::Array)
                        value.map { |v| v.is_a?(::Hash) ? v.deep_compact : v }.compact
                      else
                        value
                      end

          # ActiveSupport reports `false.blank?` as true, but false is data,
          # not absence, so it must not be dropped.
          next if compacted.blank? && compacted != false

          result[key] = compacted
        end
      end
    end
  end
end

# Extend the core Hash class with #deep_compact.
::Hash.include EasyML::CoreExt::Hash
@@ -1,9 +1,15 @@
1
1
  require "pathname"
2
2
 
3
- class Pathname
4
- def append(folder)
5
- dir = cleanpath
6
- dir = dir.join(folder) unless basename.to_s == folder
7
- dir
3
+ module EasyML
4
+ module CoreExt
5
+ module Pathname
6
+ def append(folder)
7
+ dir = cleanpath
8
+ dir = dir.join(folder) unless basename.to_s == folder
9
+ dir
10
+ end
11
+ end
8
12
  end
9
13
  end
14
+
15
+ Pathname.include EasyML::CoreExt::Pathname
@@ -0,0 +1,90 @@
module EasyML
  module Data
    # Detects whether a string column holds dates and, if so, converts it to a
    # Polars Datetime column using the matching strptime format.
    module DateConverter
      # Candidate Ruby strptime formats, tried in order. The first format that
      # parses every sampled value wins, so more specific formats should come
      # before more general ones.
      COMMON_DATE_FORMATS = [
        "%Y-%m-%dT%H:%M:%S.%6N", # e.g., "2021-01-01T00:00:00.000000"
        "%Y-%m-%d %H:%M:%S.%L", # e.g., "2021-01-01 00:01:36.000"
        "%Y-%m-%d %H:%M:%S", # e.g., "2021-01-01 00:01:36"
        "%Y-%m-%d %H:%M", # e.g., "2021-01-01 00:01"
        "%Y-%m-%d", # e.g., "2021-01-01"
        "%m/%d/%Y %H:%M:%S", # e.g., "01/01/2021 00:01:36"
        "%m/%d/%Y", # e.g., "01/01/2021"
        "%d-%m-%Y", # e.g., "01-01-2021"
        "%d-%b-%Y %H:%M:%S", # e.g., "01-Jan-2021 00:01:36"
        "%d-%b-%Y", # e.g., "01-Jan-2021"
        "%b %d, %Y", # e.g., "Jan 01, 2021"
        "%Y/%m/%d %H:%M:%S", # e.g., "2021/01/01 00:01:36"
        "%Y/%m/%d", # e.g., "2021/01/01"
      ].freeze

      # Textual rewrites from Ruby strptime fractional-second directives to the
      # directives used in the Polars format string.
      FORMAT_MAPPINGS = {
        ruby_to_polars: {
          "%L" => "%3f", # milliseconds
          "%6N" => "%6f", # microseconds
          "%N" => "%9f", # nanoseconds
        },
      }.freeze

      class << self
        # Attempts to convert a string column to datetime if it appears to be a date.
        #
        # Accepts either a DataFrame plus a column name, or a bare Series
        # (column = nil). A bare Series is wrapped in a one-column DataFrame.
        #
        # @param df [Polars::DataFrame, Polars::Series] the dataframe containing
        #   the column, or the series itself when +column+ is nil
        # @param column [String, Symbol, nil] the name of the column to convert
        # @return [Polars::DataFrame] the dataframe, with the column parsed to
        #   Datetime when a format was detected; otherwise unchanged
        def maybe_convert_date(df, column = nil)
          if column.nil?
            series = df
            column = series.name
            df = Polars::DataFrame.new(series)
          else
            series = df[column]
          end
          # Already a datetime, or not a string column: nothing to parse.
          return df if series.dtype.is_a?(Polars::Datetime)
          return df unless series.dtype == Polars::Utf8

          format = detect_polars_format(series)
          return df unless format

          df.with_column(
            Polars.col(column.to_s).str.strptime(Polars::Datetime, format).alias(column.to_s)
          )
        end

        private

        # Detects a Polars datetime format from the first 100 non-null values
        # of +series+.
        #
        # @return [String, nil] nil when +series+ is not a Polars::Series or
        #   when no known format matches
        def detect_polars_format(series)
          return nil unless series.is_a?(Polars::Series)

          sample = series.filter(series.is_not_null).head(100).to_a
          ruby_format = detect_date_format(sample)
          convert_format(:ruby_to_polars, ruby_format)
        end

        # Returns the first entry of COMMON_DATE_FORMATS that DateTime.strptime
        # accepts for every sampled (non-nil) string.
        #
        # @param date_strings [Array<String, nil>]
        # @return [String, nil] nil when nothing matches or there are no
        #   non-nil strings to check
        def detect_date_format(date_strings)
          candidates = date_strings.compact
          # Check after compacting. Otherwise an all-nil input would leave an
          # empty sample, and `all?` on [] is true, so the first format would
          # vacuously "match".
          return nil if candidates.empty?

          sample = candidates.sample([100, candidates.length].min)

          COMMON_DATE_FORMATS.detect do |format|
            sample.all? do |date_str|
              DateTime.strptime(date_str, format)
              true
            rescue StandardError
              false
            end
          end
        end

        # Rewrites +format+ using the FORMAT_MAPPINGS table named by +conversion+.
        #
        # @return [String, nil] the rewritten format, or nil when +format+ is nil
        def convert_format(conversion, format)
          return nil if format.nil?

          FORMAT_MAPPINGS[conversion].reduce(format.dup) do |result, (from, to)|
            result.gsub(from, to)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,31 @@
module EasyML
  module Data
    # Heuristics mixed into Polars::Expr. They recognize filters on a primary
    # key column and extract the numeric literals from them. Both methods
    # inspect the expression's string rendering (`to_s`), which looks like:
    #   [([(col("LOAN_APP_ID")) > (dyn int: 4)]) & ([(col("LOAN_APP_ID")) < (dyn int: 16)])]
    module FilterExtensions
      # Whether this expression appears to filter on +primary_key+.
      #
      # @param primary_key [String, Symbol, Array<String, Symbol>, nil] the key
      #   column name(s); for an array only the first name is checked
      # @return [Boolean] false in any of these cases:
      #   - no key is given (nil, false, empty array, empty name);
      #   - the expression does not reference `col("<key>")`;
      #   - none of the comparison tokens appears in the expression.
      def is_primary_key_filter?(primary_key)
        return false unless primary_key

        key = (primary_key.is_a?(Array) ? primary_key : [primary_key]).first
        return false if key.nil? || key.to_s.empty?

        expr_str = to_s
        # Match the whole column reference rather than a bare substring, so a
        # key such as "ID" does not match col("LOAN_APP_ID"). String
        # interpolation also lets Symbol keys work (String#include? rejects
        # Symbols).
        return false unless expr_str.include?(%(col("#{key}")))

        # Check for common primary key operations
        primary_key_ops = [">", "<", ">=", "<=", "=", "eq", "gt", "lt", "ge", "le"]
        primary_key_ops.any? { |op| expr_str.include?(op) }
      end

      # Extracts the distinct numeric literals from this expression's string
      # form. Literals are those rendered as `dyn int: N`, or as `float: N`
      # anywhere in the string.
      #
      # @return [Array<Float>] unique values in order of first appearance
      def extract_primary_key_values
        expr_str = to_s
        # This will match both integers and floats
        values = expr_str.scan(/(?:dyn int|float): (-?\d+(?:\.\d+)?)/).flatten.map(&:to_f)
        values.uniq
      end
    end
  end
end
# Mix the primary-key filter helpers into Polars expressions so any
# Polars::Expr responds to #is_primary_key_filter? and
# #extract_primary_key_values.
Polars::Expr.include(EasyML::Data::FilterExtensions)
@@ -0,0 +1,126 @@
1
+ require_relative "date_converter"
2
+
module EasyML
  module Data
    # Translates between EasyML's symbolic column types and Polars dtypes. Also
    # infers a symbolic type for a Polars::Series.
    module PolarsColumn
      TYPE_MAP = {
        float: Polars::Float64,
        integer: Polars::Int64,
        boolean: Polars::Boolean,
        datetime: Polars::Datetime,
        string: Polars::String,
        text: Polars::String,
        categorical: Polars::Categorical
      }
      # Reverse lookup keyed by dtype class name, e.g. "Polars::Float64" => :float.
      # NOTE: :string and :text share Polars::String. Hash#invert keeps the last
      # key for a duplicated value, so "Polars::String" maps back to :text.
      POLARS_MAP = TYPE_MAP.invert.stringify_keys
      class << self
        # Looks up the symbolic type for a dtype by the name of its class.
        #
        # @param polars_type [Object] a Polars dtype instance (e.g. series.dtype)
        # @return [Symbol, nil] the symbolic type, or nil when unmapped
        def polars_to_sym(polars_type)
          POLARS_MAP.dig(polars_type.class.to_s)
        end

        # @param symbol [Symbol] a TYPE_MAP key
        # @return [Class, nil] the matching Polars dtype class, or nil when unmapped
        def sym_to_polars(symbol)
          TYPE_MAP.dig(symbol)
        end

        # Determines the semantic type of a field from its dtype. For string
        # columns, the values are inspected as well.
        # @param series [Polars::Series] The series to analyze
        # @param polars_type [Boolean] return a Polars dtype instead of a Symbol
        # @return [Symbol, Object] one of :float, :integer, :datetime, :boolean,
        #   :text or :categorical (any other dtype is reported as :categorical).
        #   When +polars_type+ is true, returns the Polars dtype instead: the
        #   TYPE_MAP class, or the parsed column's dtype for date-like strings.
        def determine_type(series, polars_type = false)
          dtype = series.dtype

          # Classify string columns once up front. Date detection parses sample
          # values, so it should not be repeated for the case below.
          string_type = determine_string_type(series) if dtype.is_a?(Polars::Utf8)

          if string_type == :datetime
            date = EasyML::Data::DateConverter.maybe_convert_date(series)
            return polars_type ? date[date.columns.first].dtype : :datetime
          end

          type_name = case dtype
                      when Polars::Float64
                        :float
                      when Polars::Int64
                        :integer
                      when Polars::Datetime
                        :datetime
                      when Polars::Boolean
                        :boolean
                      when Polars::Utf8
                        string_type
                      else
                        :categorical
                      end

          polars_type ? sym_to_polars(type_name) : type_name
        end

        # Determines if a string field is a date, text, or categorical
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] One of :datetime, :text, or :categorical
        def determine_string_type(series)
          if EasyML::Data::DateConverter.maybe_convert_date(Polars::DataFrame.new({ temp: series }),
                                                            :temp)[:temp].dtype.is_a?(Polars::Datetime)
            :datetime
          else
            categorical_or_text?(series)
          end
        end

        # Determines if a string field is categorical or free text.
        # Despite the `?` suffix this returns a Symbol, not a Boolean.
        # @param series [Polars::Series] The string series to analyze
        # @return [Symbol] Either :categorical or :text
        def categorical_or_text?(series)
          non_null_count = series.len - series.null_count
          # An all-null (or empty) column carries no signal; treat as categorical.
          return :categorical if non_null_count == 0

          # Get value counts as percentages
          value_counts = series.value_counts(parallel: true)
          percentages = value_counts.with_column(
            (value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
          )

          # If no single value reaches 10% of the data, treat it as text
          max_percentage = percentages["percentage"].max
          return :text if max_percentage < 10.0

          # Calculate average percentage per category
          avg_percentage = 100.0 / series.n_unique

          # If average category represents less than 1% of data, it's likely text
          avg_percentage < 1.0 ? :text : :categorical
        end

        # Returns whether the field type is numeric.
        # Accepts the concrete numeric types produced by determine_type
        # (:float, :integer) as well as the generic :numeric.
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def numeric?(field_type)
          %i[numeric float integer].include?(field_type)
        end

        # Returns whether the field type is categorical
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def categorical?(field_type)
          field_type == :categorical
        end

        # Returns whether the field type is datetime
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def datetime?(field_type)
          field_type == :datetime
        end

        # Returns whether the field type is text
        # @param field_type [Symbol] The field type to check
        # @return [Boolean]
        def text?(field_type)
          field_type == :text
        end
      end
    end
  end
end