easy_ml 0.1.4 → 0.2.0.pre.rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239)
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -5,44 +5,23 @@ module EasyML
5
5
  class Tuner
6
6
  module Adapters
7
7
  class XGBoostAdapter < BaseAdapter
8
- include GlueGun::DSL
9
-
10
8
  def defaults
11
9
  {
12
10
  learning_rate: {
13
11
  min: 0.001,
14
12
  max: 0.1,
15
- log: true
13
+ log: true,
16
14
  },
17
15
  n_estimators: {
18
16
  min: 100,
19
- max: 1_000
17
+ max: 1_000,
20
18
  },
21
19
  max_depth: {
22
20
  min: 2,
23
- max: 20
24
- }
21
+ max: 20,
22
+ },
25
23
  }
26
24
  end
27
-
28
- def configure_callbacks
29
- model.customize_callbacks do |callbacks|
30
- return unless callbacks.present?
31
-
32
- wandb_callback = callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
33
- return unless wandb_callback.present?
34
-
35
- wandb_callback.project_name = "#{wandb_callback.project_name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
36
- wandb_callback.custom_loggers = [
37
- lambda do |booster, _epoch, _hist|
38
- dtrain = model.send(:preprocess, x_true, y_true)
39
- y_pred = booster.predict(dtrain)
40
- metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
41
- Wandb.log(metrics)
42
- end
43
- ]
44
- end
45
- end
46
25
  end
47
26
  end
48
27
  end
@@ -4,34 +4,36 @@ require_relative "tuner/adapters"
4
4
  module EasyML
5
5
  module Core
6
6
  class Tuner
7
- include GlueGun::DSL
8
-
9
- attribute :model
10
- attribute :dataset
11
- attribute :project_name, :string
12
- attribute :task, :string
13
- attribute :config, :hash, default: {}
14
- attribute :metrics, :array
15
- attribute :objective, :string
16
- attribute :n_trials, default: 100
17
- attribute :callbacks, :array
18
- attr_accessor :study, :results
19
-
20
- dependency :adapter, lazy: false do |dep|
21
- dep.option :xgboost do |opt|
22
- opt.set_class Adapters::XGBoostAdapter
23
- opt.bind_attribute :model
24
- opt.bind_attribute :config
25
- opt.bind_attribute :project_name
26
- opt.bind_attribute :tune_started_at
27
- opt.bind_attribute :y_true
28
- end
7
+ attr_accessor :model, :dataset, :project_name, :task, :config,
8
+ :metrics, :objective, :n_trials, :direction, :evaluator,
9
+ :study, :results, :adapter, :tune_started_at, :x_true, :y_true,
10
+ :project_name, :job, :current_run
11
+
12
+ def initialize(options = {})
13
+ @model = options[:model]
14
+ @dataset = options[:dataset]
15
+ @project_name = options[:project_name]
16
+ @task = options[:task]
17
+ @config = options[:config] || {}
18
+ @metrics = options[:metrics]
19
+ @objective = options[:objective]
20
+ @n_trials = options[:n_trials] || 100
21
+ @direction = EasyML::Core::ModelEvaluator.get(objective).new.direction
22
+ @evaluator = options[:evaluator]
23
+ @tune_started_at = EasyML::Support::UTC.now
24
+ @project_name = "#{@model.name}_#{tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")}"
25
+ end
29
26
 
30
- dep.when do |_dep|
31
- case model
32
- when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
33
- { option: :xgboost }
34
- end
27
+ def initialize_adapter
28
+ case model&.model_type
29
+ when "xgboost"
30
+ Adapters::XGBoostAdapter.new(
31
+ model: model,
32
+ config: config,
33
+ project_name: project_name,
34
+ tune_started_at: nil, # This will be set during tune
35
+ y_true: nil, # This will be set during tune
36
+ )
35
37
  end
36
38
  end
37
39
 
@@ -41,53 +43,112 @@ module EasyML
41
43
  raise "Trial failed: Stopping optimization."
42
44
  end
43
45
 
44
- def tune
45
- set_defaults!
46
+ def wandb_enabled?
47
+ EasyML::Configuration.wandb_api_key.present?
48
+ end
46
49
 
47
- @study = Optuna::Study.new
50
+ def tune(&progress_block)
51
+ set_defaults!
52
+ @adapter = initialize_adapter
53
+
54
+ tuner_params = {
55
+ model: model,
56
+ config: {
57
+ n_trials: n_trials,
58
+ objective: objective,
59
+ hyperparameter_ranges: config,
60
+ },
61
+ direction: direction,
62
+ status: :running,
63
+ started_at: Time.current,
64
+ wandb_url: wandb_enabled? ? "https://wandb.ai/fundera/#{@project_name}" : nil,
65
+ }.compact
66
+
67
+ tuner_job = EasyML::TunerJob.create!(tuner_params)
68
+ @job = tuner_job
69
+ @study = Optuna::Study.new(direction: direction)
48
70
  @results = []
71
+ model.evaluator = evaluator if evaluator.present?
49
72
  model.task = task
73
+
74
+ model.dataset.refresh
50
75
  x_true, y_true = model.dataset.test(split_ys: true)
51
- tune_started_at = EST.now
52
- adapter = pick_adapter.new(model: model, config: config, tune_started_at: tune_started_at, y_true: y_true,
53
- x_true: x_true)
54
- adapter.configure_callbacks
55
-
56
- @study.optimize(n_trials: n_trials, callbacks: [method(:loggers)]) do |trial|
57
- run_metrics = tune_once(trial, x_true, y_true, adapter)
58
-
59
- result = if model.evaluator.present?
60
- if model.evaluator_metric.present?
61
- run_metrics[model.evaluator_metric]
62
- else
63
- run_metrics[:custom]
64
- end
65
- else
66
- run_metrics[objective.to_sym]
67
- end
68
- @results.push(result)
69
- result
70
- rescue StandardError => e
71
- puts "Optuna failed with: #{e.message}"
76
+ self.x_true = x_true
77
+ self.y_true = y_true
78
+ adapter.tune_started_at = tune_started_at
79
+ adapter.y_true = y_true
80
+ adapter.x_true = x_true
81
+
82
+ model.prepare_data unless model.batch_mode
83
+ model.prepare_callbacks(self)
84
+
85
+ n_trials.times do |run_number|
86
+ trial = @study.ask
87
+ puts "Running trial #{trial.number}"
88
+ @tuner_run = tuner_job.tuner_runs.new(
89
+ trial_number: trial.number,
90
+ status: :running,
91
+ )
92
+
93
+ self.current_run = @tuner_run
94
+
95
+ begin
96
+ run_metrics = tune_once(trial, x_true, y_true, adapter, &progress_block)
97
+ result = calculate_result(run_metrics)
98
+ @results.push(result)
99
+
100
+ params = {
101
+ hyperparameters: model.hyperparameters.to_h,
102
+ value: result,
103
+ status: :success,
104
+ }.compact
105
+
106
+ @tuner_run.update!(params)
107
+ @study.tell(trial, result)
108
+ rescue StandardError => e
109
+ @tuner_run.update!(status: :failed, hyperparameters: {})
110
+ puts "Optuna failed with: #{e.message}"
111
+ raise e
112
+ end
72
113
  end
73
114
 
74
- raise "Optuna study failed" unless @study.respond_to?(:best_trial)
75
-
76
- @study.best_trial.params
115
+ model.after_tuning
116
+ return nil if tuner_job.tuner_runs.all?(&:failed?)
117
+
118
+ best_run = tuner_job.best_run
119
+ tuner_job.update!(
120
+ metadata: adapter.metadata,
121
+ best_tuner_run_id: best_run&.id,
122
+ status: :success,
123
+ completed_at: Time.current,
124
+ )
125
+
126
+ best_run&.hyperparameters
127
+ rescue StandardError => e
128
+ tuner_job&.update!(status: :failed, completed_at: Time.current)
129
+ raise e
77
130
  end
78
131
 
79
- def pick_adapter
80
- case model
81
- when EasyML::Core::Models::XGBoost, EasyML::Models::XGBoost
82
- Adapters::XGBoostAdapter
132
+ private
133
+
134
+ def calculate_result(run_metrics)
135
+ run_metrics.symbolize_keys!
136
+
137
+ if model.evaluator.present?
138
+ run_metrics[model.evaluator[:metric].to_sym]
139
+ else
140
+ run_metrics[objective.to_sym]
83
141
  end
84
142
  end
85
143
 
86
- def tune_once(trial, x_true, y_true, adapter)
144
+ def tune_once(trial, x_true, y_true, adapter, &progress_block)
87
145
  adapter.run_trial(trial) do |model|
88
- y_pred = model.predict(y_true)
146
+ model.fit(tuning: true, &progress_block)
147
+ y_pred = model.predict(x_true)
89
148
  model.metrics = metrics
90
- model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
149
+ metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
150
+ puts metrics
151
+ metrics
91
152
  end
92
153
  end
93
154
 
@@ -98,7 +159,7 @@ module EasyML
98
159
  end
99
160
  raise ArgumentError, "Objectives required for EasyML::Core::Tuner" unless objective.present?
100
161
 
101
- self.metrics = EasyML::Core::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
162
+ self.metrics = EasyML::Model.new(task: task).allowed_metrics if metrics.nil? || metrics.empty?
102
163
  end
103
164
  end
104
165
  end
data/lib/easy_ml/core.rb CHANGED
@@ -1,8 +1,5 @@
1
1
  module EasyML
2
2
  module Core
3
- require_relative "core/uploaders"
4
- require_relative "core/model"
5
- require_relative "core/models"
6
3
  require_relative "core/model_evaluator"
7
4
  require_relative "core/tuner"
8
5
  end
@@ -0,0 +1,24 @@
module EasyML
  module CoreExt
    # Mixin that adds #deep_compact to ::Hash (installed at the bottom of this file).
    module Hash
      # Returns a new hash with nil and blank values removed, recursively.
      #
      # - nil values are skipped.
      # - Nested hashes are compacted recursively.
      # - Arrays lose their nil elements, and hash elements inside them are compacted.
      # - A key is dropped when its compacted value is blank (per #blank?),
      #   except that +false+ is kept: it is a meaningful value such as a disabled
      #   flag, not a missing one.
      #
      # The receiver is not modified.
      #
      # @return [::Hash] a new, compacted hash
      def deep_compact
        each_with_object({}) do |(key, value), result|
          next if value.nil?

          # Reference ::Hash / ::Array explicitly: inside this module a bare `Hash`
          # resolves to EasyML::CoreExt::Hash rather than the core class.
          compacted = case value
                      when ::Hash
                        value.deep_compact
                      when ::Array
                        value.map { |v| v.is_a?(::Hash) ? v.deep_compact : v }.compact
                      else
                        value
                      end

          # blank? (ActiveSupport) treats false as blank; keep false explicitly.
          next if compacted.blank? && compacted != false

          result[key] = compacted
        end
      end
    end
  end
end

# Extend Hash class with our custom method
Hash.include EasyML::CoreExt::Hash
@@ -1,9 +1,15 @@
1
1
  require "pathname"
2
2
 
3
- class Pathname
4
- def append(folder)
5
- dir = cleanpath
6
- dir = dir.join(folder) unless basename.to_s == folder
7
- dir
3
+ module EasyML
4
+ module CoreExt
5
+ module Pathname
6
+ def append(folder)
7
+ dir = cleanpath
8
+ dir = dir.join(folder) unless basename.to_s == folder
9
+ dir
10
+ end
11
+ end
8
12
  end
9
13
  end
14
+
15
+ Pathname.include EasyML::CoreExt::Pathname
@@ -0,0 +1,90 @@
module EasyML
  module Data
    # Detects string columns whose values are dates and converts them to Polars
    # Datetime columns.
    module DateConverter
      # Ruby strptime formats tried in order; the first format that parses every
      # sampled value wins, so more specific formats are listed first.
      COMMON_DATE_FORMATS = [
        "%Y-%m-%dT%H:%M:%S.%6N", # e.g., "2021-01-01T00:00:00.000000"
        "%Y-%m-%d %H:%M:%S.%L", # e.g., "2021-01-01 00:01:36.000"
        "%Y-%m-%d %H:%M:%S", # e.g., "2021-01-01 00:01:36"
        "%Y-%m-%d %H:%M", # e.g., "2021-01-01 00:01"
        "%Y-%m-%d", # e.g., "2021-01-01"
        "%m/%d/%Y %H:%M:%S", # e.g., "01/01/2021 00:01:36"
        "%m/%d/%Y", # e.g., "01/01/2021"
        "%d-%m-%Y", # e.g., "01-01-2021"
        "%d-%b-%Y %H:%M:%S", # e.g., "01-Jan-2021 00:01:36"
        "%d-%b-%Y", # e.g., "01-Jan-2021"
        "%b %d, %Y", # e.g., "Jan 01, 2021"
        "%Y/%m/%d %H:%M:%S", # e.g., "2021/01/01 00:01:36"
        "%Y/%m/%d", # e.g., "2021/01/01"
      ].freeze

      # Maximum number of non-null values inspected when detecting a format.
      SAMPLE_SIZE = 100

      # Ruby fractional-second directives and the Polars equivalents they are rewritten to.
      FORMAT_MAPPINGS = {
        ruby_to_polars: {
          "%L" => "%3f", # milliseconds
          "%6N" => "%6f", # microseconds
          "%N" => "%9f", # nanoseconds
        },
      }.freeze

      class << self
        # Attempts to convert a string column to datetime if it appears to be a date.
        #
        # @param df [Polars::DataFrame, Polars::Series] the dataframe containing the
        #   column; when +column+ is nil this is a Series, which is wrapped in a
        #   one-column DataFrame named after the series
        # @param column [String, Symbol, nil] the name of the column to convert
        # @return [Polars::DataFrame] the dataframe, with the column parsed to
        #   Datetime when a format from COMMON_DATE_FORMATS matched every sampled value
        def maybe_convert_date(df, column = nil)
          if column.nil?
            series = df
            column = series.name
            df = Polars::DataFrame.new(series)
          else
            series = df[column]
          end
          return df if series.dtype.is_a?(Polars::Datetime)
          return df unless series.dtype == Polars::Utf8

          format = detect_polars_format(series)
          return df unless format

          df.with_column(
            Polars.col(column.to_s).str.strptime(Polars::Datetime, format).alias(column.to_s)
          )
        end

        private

        # Detects a date format from the first SAMPLE_SIZE non-null values of a
        # series and returns it in Polars syntax, or nil if none matches.
        def detect_polars_format(series)
          return nil unless series.is_a?(Polars::Series)

          sample = series.filter(series.is_not_null).head(SAMPLE_SIZE).to_a
          ruby_format = detect_date_format(sample)
          convert_format(:ruby_to_polars, ruby_format)
        end

        # Returns the first COMMON_DATE_FORMATS entry that DateTime.strptime
        # accepts for every value in the sample, or nil.
        #
        # The sample is the first SAMPLE_SIZE non-nil values, so the result is
        # deterministic for a given input.
        def detect_date_format(date_strings)
          sample = date_strings.compact.first(SAMPLE_SIZE)
          # With nothing to check, `all?` would be vacuously true and the first
          # format would be reported; treat that as "not a date" instead.
          return nil if sample.empty?

          COMMON_DATE_FORMATS.detect do |format|
            sample.all? do |date_str|
              DateTime.strptime(date_str, format)
              true
            rescue StandardError
              false
            end
          end
        end

        # Rewrites a Ruby strptime format using the FORMAT_MAPPINGS table for
        # +conversion+; returns nil when +format+ is nil.
        def convert_format(conversion, format)
          return nil if format.nil?

          result = format.dup
          FORMAT_MAPPINGS[conversion].each do |from, to|
            result = result.gsub(from, to)
          end
          result
        end
      end
    end
  end
end
@@ -0,0 +1,31 @@
module EasyML
  module Data
    # String-based heuristics mixed into Polars::Expr for recognising filters on
    # a primary key column by inspecting the expression's #to_s output.
    module FilterExtensions
      # Returns true when this expression mentions the (first) primary key column
      # together with a comparison-like operator token.
      #
      # Filter expressions in Polars are represented as strings like:
      #   [([(col("LOAN_APP_ID")) > (dyn int: 4)]) & ([(col("LOAN_APP_ID")) < (dyn int: 16)])]
      # Only the first column of a composite key is checked, and both checks are
      # plain substring matches.
      #
      # @param primary_key [String, Symbol, Array<String, Symbol>, nil]
      # @return [Boolean]
      def is_primary_key_filter?(primary_key)
        return false unless primary_key

        primary_key = [primary_key] unless primary_key.is_a?(Array)
        # Stringify so Symbol keys work (String#include? raises TypeError on a Symbol),
        # and treat an empty key list / empty name as "no key".
        key = primary_key.first.to_s
        return false if key.empty?

        expr_str = to_s
        return false unless expr_str.include?(key)

        # Check for common primary key operations.
        # NOTE(review): substring matching; short tokens such as "lt"/"le" may match unrelated text.
        primary_key_ops = [">", "<", ">=", "<=", "=", "eq", "gt", "lt", "ge", "le"]
        primary_key_ops.any? { |op| expr_str.include?(op) }
      end

      # Extracts the literal numbers that appear as `dyn int: N` or `float: N` in
      # the expression's string form.
      #
      # @return [Array<Float>] unique values in order of first appearance; integers
      #   are returned as Floats
      def extract_primary_key_values
        expr_str = to_s
        # This will match both integers and floats
        values = expr_str.scan(/(?:dyn int|float): (-?\d+(?:\.\d+)?)/).flatten.map(&:to_f)
        values.uniq
      end
    end
  end
end

# Extend Polars classes with our filter functionality
[Polars::Expr].each do |klass|
  klass.include(EasyML::Data::FilterExtensions)
end
@@ -0,0 +1,126 @@
require_relative "date_converter"

module EasyML
  module Data
    # Maps between EasyML's symbolic column types and Polars dtypes, and infers
    # a column's symbolic type from its data.
    module PolarsColumn
      # Symbolic type => Polars dtype class.
      TYPE_MAP = {
        float: Polars::Float64,
        integer: Polars::Int64,
        boolean: Polars::Boolean,
        datetime: Polars::Datetime,
        string: Polars::String,
        text: Polars::String,
        categorical: Polars::Categorical
      }
      # Polars dtype class name => symbolic type. :string and :text share
      # Polars::String, and Hash#invert keeps the later key, so it maps to :text.
      POLARS_MAP = TYPE_MAP.invert.stringify_keys

      class << self
        # @param polars_type [Object] a Polars dtype instance; it is looked up by its class name
        # @return [Symbol, nil] the symbolic type, or nil if unmapped
        def polars_to_sym(polars_type)
          POLARS_MAP.dig(polars_type.class.to_s)
        end

        # @param symbol [Symbol] a key of TYPE_MAP
        # @return [Class, nil] the Polars dtype class, or nil if unmapped
        def sym_to_polars(symbol)
          TYPE_MAP.dig(symbol)
        end

        # Determines the semantic type of a field based on its dtype and, for
        # string series, its contents.
        #
        # @param series [Polars::Series] the series to analyze
        # @param polars_type [Boolean] return a Polars dtype instead of a symbol
        # @return [Symbol, Object] one of :float, :integer, :datetime, :boolean,
        #   :categorical or :text. When +polars_type+ is true, returns the matching
        #   TYPE_MAP class instead, or the parsed column's dtype for date-like strings.
        def determine_type(series, polars_type = false)
          dtype = series.dtype

          type_name = case dtype
                      when Polars::Float64 then :float
                      when Polars::Int64 then :integer
                      when Polars::Datetime then :datetime
                      when Polars::Boolean then :boolean
                      when Polars::Utf8
                        # Classify the strings once and reuse the answer; this
                        # includes date detection, the expensive step.
                        string_type = determine_string_type(series)
                        if string_type == :datetime
                          date = EasyML::Data::DateConverter.maybe_convert_date(series)
                          return polars_type ? date[date.columns.first].dtype : :datetime
                        end

                        string_type
                      else
                        # NOTE(review): all other dtypes (e.g. Int32, Float32) fall back to :categorical.
                        :categorical
                      end

          polars_type ? sym_to_polars(type_name) : type_name
        end

        # Determines if a string field is a date, text, or categorical.
        #
        # @param series [Polars::Series] the string series to analyze
        # @return [Symbol] :datetime when DateConverter can parse it, else :text or :categorical
        def determine_string_type(series)
          if EasyML::Data::DateConverter.maybe_convert_date(Polars::DataFrame.new({ temp: series }),
                                                            :temp)[:temp].dtype.is_a?(Polars::Datetime)
            :datetime
          else
            categorical_or_text?(series)
          end
        end

        # Determines if a string field is categorical or free text.
        #
        # A series is :text when no value accounts for at least 10% of the
        # non-null rows, or when 100 / n_unique is below 1 (more than 100
        # distinct values). Otherwise, including all-null series, it is :categorical.
        #
        # @param series [Polars::Series] the string series to analyze
        # @return [Symbol] either :categorical or :text
        def categorical_or_text?(series)
          # An all-null series also means non_null_count would be zero below.
          return :categorical if series.null_count == series.len

          non_null_count = series.len - series.null_count

          # Get value counts as percentages of the non-null rows
          value_counts = series.value_counts(parallel: true)
          percentages = value_counts.with_column(
            (value_counts["count"] / non_null_count.to_f * 100).alias("percentage")
          )

          # Check if any category represents more than 10% of the data
          max_percentage = percentages["percentage"].max
          return :text if max_percentage < 10.0

          # Calculate average percentage per category
          avg_percentage = 100.0 / series.n_unique

          # If average category represents less than 1% of data, it's likely text
          avg_percentage < 1.0 ? :text : :categorical
        end

        # Returns whether the field type is numeric.
        # NOTE(review): determine_type returns :float/:integer, never :numeric,
        # so this is false for its results — confirm intended.
        # @param field_type [Symbol] the field type to check
        # @return [Boolean]
        def numeric?(field_type)
          field_type == :numeric
        end

        # Returns whether the field type is categorical.
        # @param field_type [Symbol] the field type to check
        # @return [Boolean]
        def categorical?(field_type)
          field_type == :categorical
        end

        # Returns whether the field type is datetime.
        # @param field_type [Symbol] the field type to check
        # @return [Boolean]
        def datetime?(field_type)
          field_type == :datetime
        end

        # Returns whether the field type is text.
        # @param field_type [Symbol] the field type to check
        # @return [Boolean]
        def text?(field_type)
          field_type == :text
        end
      end
    end
  end
end