easy_ml 0.1.4 → 0.2.0.pre.rc1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239):
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,68 +1,602 @@
1
- require_relative "../../../lib/easy_ml/core/model"
1
+ # == Schema Information
2
+ #
3
+ # Table name: easy_ml_models
4
+ #
5
+ # id :bigint not null, primary key
6
+ # name :string not null
7
+ # model_type :string
8
+ # status :string
9
+ # dataset_id :bigint
10
+ # model_file_id :bigint
11
+ # configuration :json
12
+ # version :string not null
13
+ # root_dir :string
14
+ # file :json
15
+ # sha :string
16
+ # last_trained_at :datetime
17
+ # is_training :boolean
18
+ # created_at :datetime not null
19
+ # updated_at :datetime not null
20
+ #
21
+ require_relative "models/hyperparameters"
22
+
2
23
  module EasyML
3
24
  class Model < ActiveRecord::Base
4
- if ActiveRecord::Base.connection.data_source_exists?("easy_ml_models")
5
- include EasyML::Core::ModelCore
6
-
7
- self.table_name = "easy_ml_models"
8
- else
9
- # Placeholder if the table doesn't exist (keeps the file quiet)
10
- def self.table_ready?
11
- false
25
+ self.table_name = "easy_ml_models"
26
+ include Historiographer::Silent
27
+ historiographer_mode :snapshot_only
28
+
29
+ include EasyML::Concerns::Configurable
30
+ include EasyML::Concerns::Versionable
31
+
32
+ self.filter_attributes += [:configuration]
33
+
34
+ MODEL_OPTIONS = {
35
+ "xgboost" => "EasyML::Models::XGBoost",
36
+ }
37
+ MODEL_TYPES = [
38
+ {
39
+ value: "xgboost",
40
+ label: "XGBoost",
41
+ description: "Extreme Gradient Boosting, a scalable and accurate implementation of gradient boosting machines",
42
+ },
43
+ ].freeze
44
+ MODEL_NAMES = MODEL_OPTIONS.keys.freeze
45
+ MODEL_CONSTANTS = MODEL_OPTIONS.values.map(&:constantize)
46
+
47
+ add_configuration_attributes :task, :objective, :hyperparameters, :evaluator, :callbacks, :metrics
48
+ MODEL_CONSTANTS.flat_map(&:configuration_attributes).each do |attribute|
49
+ add_configuration_attributes attribute
50
+ end
51
+
52
+ belongs_to :dataset
53
+ belongs_to :model_file, class_name: "EasyML::ModelFile", foreign_key: "model_file_id", optional: true
54
+
55
+ has_one :retraining_job, class_name: "EasyML::RetrainingJob"
56
+ accepts_nested_attributes_for :retraining_job
57
+ has_many :retraining_runs, class_name: "EasyML::RetrainingRun"
58
+ has_many :deploys, class_name: "EasyML::Deploy"
59
+
60
+ scope :deployed, -> { EasyML::ModelHistory.deployed }
61
+
62
+ def latest_deploy
63
+ deploys.order(id: :desc).limit(1).last
64
+ end
65
+
66
+ after_initialize :bump_version, if: -> { new_record? }
67
+ after_initialize :set_defaults, if: -> { new_record? }
68
+ before_save :save_model_file, if: -> { is_fit? && !is_history_class? && model_changed? && !@skip_save_model_file }
69
+
70
+ VALID_TASKS = %i[regression classification].freeze
71
+
72
+ TASK_TYPES = [
73
+ {
74
+ value: "classification",
75
+ label: "Classification",
76
+ description: "Predict categorical outcomes or class labels",
77
+ },
78
+ {
79
+ value: "regression",
80
+ label: "Regression",
81
+ description: "Predict continuous numerical values",
82
+ },
83
+ ].freeze
84
+
85
+ validates :name, presence: true
86
+ validates :name, uniqueness: { case_sensitive: false }
87
+ validates :task, presence: true
88
+ validates :task, inclusion: {
89
+ in: VALID_TASKS.map { |t| [t, t.to_s] }.flatten,
90
+ message: "must be one of: #{VALID_TASKS.join(", ")}",
91
+ }
92
+ validates :model_type, inclusion: { in: MODEL_NAMES }
93
+ validates :dataset_id, presence: true
94
+ validate :validate_metrics_allowed
95
+ before_save :set_root_dir
96
+
97
+ delegate :prepare_data, :preprocess, to: :adapter
98
+
99
# Lifecycle states a model can be in. For each state a predicate is generated
# (development?, inference?, retired?) that reports whether the record's
# `status` attribute currently equals that state.
STATUSES = %w[development inference retired].freeze
STATUSES.each do |status_name|
  define_method "#{status_name}?" do
    # Compare as strings: `to_s` tolerates a nil status (the original
    # `status.to_sym` raised NoMethodError on nil) and works the same
    # whether the attribute holds a String or a Symbol.
    self.status.to_s == status_name
  end
end
13
105
 
14
- Rails.logger.info("Skipping EasyML::Model definition as the 'easy_ml_models' table doesn't exist.")
106
+ def training?
107
+ is_training == true
15
108
  end
16
109
 
17
- scope :live, -> { where(is_live: true) }
18
- attribute :root_dir, :string
19
- after_initialize :apply_defaults
110
# Starts a training run for this model.
#
# Calls pending_run first (so a pending run record exists before any work is
# scheduled), flags the record with is_training: true, and then either
# enqueues EasyML::TrainingJob with this model's id (async: true, the default)
# or calls actually_train inline (async: false).
#
# NOTE(review): the return value of update(...) is not checked, so a failed
# update does not stop the job from being enqueued / training from running.
+ def train(async: true)
111
+ pending_run # Ensure we update the pending job before enqueuing in background so UI updates properly
112
+ update(is_training: true)
113
+ if async
114
+ EasyML::TrainingJob.perform_later(id)
115
+ else
116
+ actually_train
117
+ end
118
+ end
20
119
 
21
- validate :only_one_model_is_live?
22
- def only_one_model_is_live?
23
- return if @marking_live
120
# Returns this model's retraining job, creating one when none exists.
#
# Side effect: always reassigns self.evaluator, either from the existing
# job's evaluator or from Core::ModelEvaluator.default_evaluator(task).
#
# A newly created job is inactive (active: false), scheduled with
# frequency "month" at hour 0 on day_of_month 1, and takes its metric,
# direction and threshold from the evaluator hash.
#
# NOTE(review): only the existing-job branch calls symbolize_keys; the
# default evaluator is indexed with symbol keys as-is. Presumably
# default_evaluator already returns symbol keys - confirm.
+ def get_retraining_job
121
+ if retraining_job
122
+ self.evaluator = retraining_job.evaluator
123
+ evaluator = self.evaluator.symbolize_keys
124
+ else
125
+ default_eval = Core::ModelEvaluator.default_evaluator(task)
126
+ self.evaluator = default_eval
127
+ evaluator = default_eval
128
+ end
129
+
130
+ retraining_job || create_retraining_job(
131
+ model: self,
132
+ active: false,
133
+ evaluator: evaluator,
134
+ metric: evaluator[:metric],
135
+ direction: evaluator[:direction],
136
+ threshold: evaluator[:threshold],
137
+ frequency: "month",
138
+ at: { hour: 0, day_of_month: 1 },
139
+ )
140
+ end
24
141
 
25
- if previous_versions.live.count > 1
26
- raise "Multiple previous versions of #{name} are live! This should never happen. Update previous versions to is_live=false before proceeding"
142
+ def pending_run
143
+ job = get_retraining_job
144
+ job.retraining_runs.find_or_create_by(status: "pending", model: self)
145
+ end
146
+
147
# Runs training synchronously while holding the model's training lock.
#
# Inside lock_model: fetches the pending run, and within run.wrap_training
# optionally runs hyperparameter_search (when run.should_tune? is true),
# then fit and save; the wrap_training block evaluates to
# [self, best_params]. Afterwards is_training is cleared and the reloaded
# run is the value of the lock_model block.
#
# progress_block is forwarded to both hyperparameter_search and fit.
#
# The ensure clause belongs to the lock_model block and always calls unlock!.
#
# NOTE(review): update(is_training: false) is in the block body, not in the
# ensure clause, so it is skipped if an exception propagates out of
# wrap_training/fit - confirm wrap_training rescues, otherwise the record
# can stay flagged as training.
+ def actually_train(&progress_block)
148
+ lock_model do
149
+ run = pending_run
150
+ run.wrap_training do
151
+ best_params = nil
152
+ if run.should_tune?
153
+ best_params = hyperparameter_search(&progress_block)
154
+ end
155
+ fit(&progress_block)
156
+ save
157
+ [self, best_params]
158
+ end
159
+ update(is_training: false)
160
+ run.reload
161
+ ensure
162
+ unlock!
163
+ end
164
+ end
165
+
166
+ def unlock!
167
+ Support::Lockable.unlock!(lock_key)
168
+ end
169
+
170
# Yields to the caller's block while holding the lock obtained through
# with_lock. The lock client passed by with_lock is accepted but not
# forwarded to the caller's block.
+ def lock_model
171
+ with_lock do |client|
172
+ yield
173
+ end
174
+ end
175
+
176
# Delegates to EasyML::Support::Lockable.with_lock using this model's
# lock_key, with stale_timeout: 60 and resources: 1, and yields the client
# it receives to the caller's block.
#
# NOTE(review): ActiveRecord::Base already provides an instance method named
# with_lock (pessimistic row locking); this definition replaces it for this
# model - confirm nothing relies on the ActiveRecord version.
+ def with_lock
177
+ EasyML::Support::Lockable.with_lock(lock_key, stale_timeout: 60, resources: 1) do |client|
178
+ yield client
179
+ end
180
+ end
181
+
182
+ def lock_key
183
+ "training:#{self.name}:#{self.id}"
184
+ end
185
+
186
# Runs a hyperparameter search and applies the winning values to this model.
#
# Builds the tuner options from retraining_job.tuner_config (keys
# symbolized), merges in the evaluator, this model and its dataset (nil
# entries dropped by compact), and passes them to EasyML::Core::Tuner.new.
# Each key/value pair returned by tune is written onto hyperparameters via
# the matching "key=" setter. Because of tap, the method returns the value
# returned by tune.
#
# NOTE(review): retraining_job is dereferenced without a nil check, so this
# raises when the model has no retraining job.
+ def hyperparameter_search(&progress_block)
187
+ tuner = retraining_job.tuner_config.symbolize_keys
188
+ extra_params = {
189
+ evaluator: evaluator,
190
+ model: self,
191
+ dataset: dataset,
192
+ }.compact
193
+ tuner.merge!(extra_params)
194
+ tuner_instance = EasyML::Core::Tuner.new(tuner)
195
+ tuner_instance.tune(&progress_block).tap do |best_params|
196
+ best_params.each do |key, value|
197
+ self.hyperparameters.send("#{key}=", value)
198
+ end
27
199
  end
200
+ end
201
+
202
+ def deployment_status
203
+ status
204
+ end
205
+
206
+ def formatted_model_type
207
+ adapter.class.name.split("::").last
208
+ end
209
+
210
+ def formatted_version
211
+ return nil unless version
212
+ Time.strptime(version, "%Y%m%d%H%M%S").strftime("%B %-d, %Y at %-l:%M %p")
213
+ end
28
214
 
29
- return unless previous_versions.live.any? && is_live
215
+ def last_run_at
216
+ last_run&.created_at
217
+ end
218
+
219
+ def last_run
220
+ retraining_runs.order(id: :desc).limit(1).last
221
+ end
222
+
223
+ def inference_version
224
+ latest_deploy&.model_version
225
+ end
226
+
227
+ alias_method :current_version, :inference_version
228
+ alias_method :latest_version, :inference_version
229
+ alias_method :deployed, :inference_version
30
230
 
31
- errors.add(:is_live,
32
- "cannot mark model live when previous version is live. Explicitly use the mark_live method to mark this as the live version")
231
+ def hyperparameters
232
+ @hypers ||= adapter.build_hyperparameters(@hyperparameters)
33
233
  end
34
234
 
35
- def mark_live
36
- transaction do
37
- self.class.where(name: name).where.not(id: id).update_all(is_live: false)
38
- self.class.where(id: id).update_all(is_live: true)
235
+ def callbacks
236
+ @cbs ||= adapter.build_callbacks(@callbacks)
237
+ end
238
+
239
+ def predict(xs)
240
+ load_model!
241
+ adapter.predict(xs)
242
+ end
243
+
244
# Writes the trained model to disk and uploads it.
#
# Raises a RuntimeError when is_fit? is false. Returns early (nil) when the
# adapter reports that no model is loaded.
#
# Otherwise: forces a version bump, asks the model file for the path of that
# version, has the adapter save to it, uploads the path returned by the
# adapter, saves the model-file record, assigns it to self.model_file and
# finally calls cleanup.
#
# This method is also registered as a before_save callback in this class
# (guarded by is_fit?, model_changed? and @skip_save_model_file).
+ def save_model_file
245
+ raise "No trained model! Need to train model before saving (call model.fit)" unless is_fit?
246
+ return unless adapter.loaded?
247
+
248
+ model_file = get_model_file
249
+
250
+ bump_version(force: true)
251
+ path = model_file.full_path(version)
252
+ full_path = adapter.save_model_file(path)
253
+ model_file.upload(full_path)
254
+
255
+ model_file.save
256
+ self.model_file = model_file
257
+ cleanup
258
+ end
259
+
260
+ def feature_names
261
+ adapter.feature_names
262
+ end
263
+
264
+ def cleanup!
265
+ get_model_file&.cleanup!
266
+ end
267
+
268
+ def cleanup
269
+ get_model_file&.cleanup(files_to_keep)
270
+ end
271
+
272
# Reports whether the model is loaded in the adapter, loading it from the
# local model file first when it is not.
#
# Returns false immediately when the model file record is persisted but no
# file exists at its full_path. Otherwise, if the adapter has nothing
# loaded, load_model_file is attempted, and the result is
# file_exists && adapter.loaded?.
#
# Note that despite the predicate name this method has a side effect: it may
# load the model file into the adapter.
#
# NOTE(review): the second existence check looks redundant - for a persisted
# file with a path, the guard above has already returned false when the
# file is missing, so file_exists appears to always be true here. Confirm
# before simplifying.
+ def loaded?
273
+ model_file = get_model_file
274
+ return false if model_file.persisted? && !File.exist?(model_file.full_path.to_s)
275
+
276
+ file_exists = true
277
+ if model_file.present? && model_file.persisted? && model_file.full_path.present?
278
+ file_exists = File.exist?(model_file.full_path)
39
279
  end
280
+
281
+ loaded = adapter.loaded?
282
+ load_model_file unless loaded
283
+ file_exists && adapter.loaded?
40
284
  end
41
285
 
42
- def previous_versions
43
- EasyML::Model.where(name: name).order(id: :desc)
286
# Reports whether the fitted model differs from the currently deployed one.
#
# Returns false when the model is not fit; true when nothing has been
# deployed yet or when the associated model file has not been persisted;
# otherwise delegates to the adapter, comparing against the deployed
# version's sha.
def model_changed?
  return false unless is_fit?

  # Look the deployed version up once; the original queried it up to three
  # times per call.
  deployed_version = inference_version
  return true if deployed_version.nil?
  return true if model_file.present? && !model_file.persisted?

  # The original had a further guard requiring inference_version.nil?, which
  # could never be true past the nil check above, so it has been dropped.
  adapter.model_changed?(deployed_version.sha)
end
294
+
295
+ def feature_importances
296
+ adapter.feature_importances
297
+ end
298
+
299
+ def fit_in_batches?
300
+ retraining_job.present? && retraining_job.batch_mode == true
301
+ end
302
+
303
# Fits the model through the adapter and marks it as fit.
#
# When fit_in_batches? is true, delegates to fit_in_batches with batch_args
# plus the tuning flag and returns its result; in that path the
# x_train/y_train/x_valid/y_valid arguments are not used and
# dataset.refresh is not called from here.
#
# Otherwise refreshes the dataset, forwards all keyword arguments and the
# progress block to adapter.fit, and sets @is_fit to true.
+ def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
304
+ return fit_in_batches(**batch_args.merge!(tuning: tuning), &progress_block) if fit_in_batches?
305
+
306
+ dataset.refresh
307
+ adapter.fit(tuning: tuning, x_train: x_train, y_train: y_train, x_valid: x_valid, y_valid: y_valid, &progress_block)
308
+ @is_fit = true
309
+ end
310
+
311
# Builds the keyword arguments used for batch fitting.
#
# Starts from the defaults batch_size: 1024, batch_overlap: 3,
# batch_key: nil and overlays any non-nil values found on the retraining
# job (nil values are removed by compact, so they never override a default).
# Returns a fresh hash on every call.
+ def batch_args
312
+ defaults = {
313
+ batch_size: 1024,
314
+ batch_overlap: 3,
315
+ batch_key: nil,
316
+ }
317
+ overrides = { batch_size: retraining_job&.batch_size, batch_overlap: retraining_job&.batch_overlap, batch_key: retraining_job&.batch_key }.compact
318
+ defaults.merge!(overrides)
319
+ end
320
+
321
+ def batch_mode
322
+ retraining_job&.batch_mode || false
323
+ end
324
+
325
+ def prepare_callbacks(tune_started_at)
326
+ adapter.prepare_callbacks(tune_started_at)
327
+ end
328
+
329
+ def after_tuning
330
+ adapter.after_tuning
331
+ end
332
+
333
+ def fit_in_batches(tuning: false, batch_size: nil, batch_overlap: nil, batch_key: nil, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"), &progress_block)
334
+ adapter.fit_in_batches(tuning: tuning, batch_size: batch_size, batch_overlap: batch_overlap, batch_key: batch_key, checkpoint_dir: checkpoint_dir, &progress_block)
335
+ @is_fit = true
336
+ end
337
+
338
+ attr_accessor :is_fit
339
+
340
# Reports whether a trained model is available: true when the model file
# (existing or freshly built by get_model_file) is present and reports
# fit?, otherwise whatever the adapter's is_fit? returns.
+ def is_fit?
341
+ model_file = get_model_file
342
+ return true if model_file.present? && model_file.fit?
343
+
344
+ adapter.is_fit?
345
+ end
346
+
347
+ def deployable?
348
+ cannot_deploy_reasons.none?
349
+ end
350
+
351
+ def decode_labels(ys, col: nil)
352
+ dataset.decode_labels(ys, col: col)
353
+ end
354
+
355
# Evaluates predictions with EasyML::Core::ModelEvaluator.
#
# evaluator defaults to this model's evaluator. When y_pred is nil, all
# three of y_pred, y_true and x_true are replaced by the values from
# default_evaluation_inputs - including any y_true/x_true the caller
# passed in.
+ def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
356
+ evaluator ||= self.evaluator
357
+ if y_pred.nil?
358
+ inputs = default_evaluation_inputs
359
+ y_pred = inputs[:y_pred]
360
+ y_true = inputs[:y_true]
361
+ x_true = inputs[:x_true]
362
+ end
363
+ EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true, evaluator: evaluator)
364
+ end
365
+
366
+ def evaluator
367
+ instance_variable_get(:@evaluator) || default_evaluator
368
+ end
369
+
370
+ def default_evaluator
371
+ return nil unless task.present?
372
+
373
+ EasyML::Core::ModelEvaluator.default_evaluator(task)
374
+ end
375
+
376
+ def get_params
377
+ @hyperparameters.to_h
378
+ end
379
+
380
+ def evals
381
+ last_run&.metrics || {}
382
+ end
383
+
384
# Looks up a single metric value recorded on the most recent retraining run.
#
# metric may be a String or Symbol; the run's metrics keys are symbolized
# before the lookup.
#
# Returns nil when the model has no run yet, when the run has no metrics,
# or when the metric is not recorded. The original raised NoMethodError in
# the first two cases, unlike the nil-safe sibling `evals`.
def metric_accessor(metric)
  recorded = last_run&.metrics
  return nil if recorded.nil?

  recorded.symbolize_keys[metric.to_sym]
end
388
+
389
+ EasyML::Core::ModelEvaluator.metrics.each do |metric_name|
390
+ define_method metric_name do
391
+ metric_accessor(metric_name)
392
+ end
393
+ end
394
+
395
+ EasyML::Core::ModelEvaluator.callbacks = lambda do |metric_name|
396
+ EasyML::Model.define_method metric_name do
397
+ metric_accessor(metric_name)
398
+ end
399
+ end
400
+
401
+ def allowed_metrics
402
+ EasyML::Core::ModelEvaluator.metrics(task).map(&:to_s)
403
+ end
404
+
405
+ def default_metrics
406
+ return [] unless task.present?
407
+
408
+ case task.to_sym
409
+ when :regression
410
+ %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
411
+ when :classification
412
+ %w[accuracy_score precision_score recall_score f1_score]
413
+ else
414
+ []
415
+ end
416
+ end
417
+
418
+ def self.constants
419
+ {
420
+ objectives: objectives_by_model_type,
421
+ metrics: metrics_by_task,
422
+ tasks: TASK_TYPES,
423
+ timezone: EasyML::Configuration.timezone_label,
424
+ retraining_job_constants: EasyML::RetrainingJob.constants,
425
+ tuner_job_constants: EasyML::TunerJob.constants,
426
+ }
427
+ end
428
+
429
+ def self.metrics_by_task
430
+ EasyML::Core::ModelEvaluator.metrics_by_task
431
+ end
432
+
433
+ def self.objectives_by_model_type
434
+ MODEL_OPTIONS.inject({}) do |h, (k, v)|
435
+ h.tap do
436
+ h[k] = v.constantize.const_get(:OBJECTIVES_FRONTEND)
437
+ end
438
+ end.deep_symbolize_keys
439
+ end
440
+
441
+ def attributes
442
+ super.merge!(
443
+ hyperparameters: hyperparameters.to_h,
444
+ )
445
+ end
446
+
447
+ class CannotdeployError < StandardError
448
+ end
449
+
450
+ def deploy(async: true)
451
+ last_run.deploy(async: async)
452
+ end
453
+
454
# Performs the deploy: snapshots the current model and prepares the record
# for the next round of training.
#
# Raises CannotdeployError with the first entry of cannot_deploy_reasons
# when there is any. Returns the value of snapshot (because of tap).
#
# Order matters: the model file is saved and its sha copied onto the model
# before the snapshot is taken; the version bump, dataset version bump and
# the fresh model file are applied only after the snapshot.
+ def actually_deploy
455
+ raise CannotdeployError, cannot_deploy_reasons.first if cannot_deploy_reasons.any?
456
+
457
+ # Prepare the inference model by freezing + saving the model, dataset, and datasource
458
+ # (This creates ModelHistory, DatasetHistory, etc)
459
+ save_model_file
460
+ self.sha = model_file.sha
461
+ save
462
+ dataset.upload_remote_files
463
+ snapshot.tap do
464
+ # Prepare the model to be retrained (reset values so they don't conflict with our snapshotted version)
465
+ bump_version(force: true)
466
+ dataset.bump_versions(version)
467
+ self.model_file = new_model_file!
468
+ save
469
+ end
470
+ end
471
+
472
# Returns an array of human-readable reasons why the model cannot be
# deployed; empty when it can. Each check contributes nil (dropped by
# compact) when it passes: the model must be fit, the dataset must have a
# target, and the dataset's datasource must not be in-memory.
+ def cannot_deploy_reasons
473
+ [
474
+ is_fit? ? nil : "Model has not been trained",
475
+ dataset.target.present? ? nil : "Dataset has no target",
476
+ !dataset.datasource.in_memory? ? nil : "Cannot perform inference using an in-memory datasource",
477
+ ].compact
478
+ end
479
+
480
# Guarded writer for root_dir: the value is derived (see the root_dir
# reader), so the only accepted assignment is one whose string form equals
# the derived path. Anything else raises a RuntimeError.
+ def root_dir=(value)
481
+ raise "Cannot override value of root_dir!" unless value.to_s == root_dir.to_s
482
+
483
+ write_attribute(:root_dir, value)
484
+ end
485
+
486
+ def set_root_dir
487
+ write_attribute(:root_dir, root_dir)
488
+ end
489
+
490
+ def root_dir
491
+ EasyML::Engine.root_dir.join("models").join(underscored_name).to_s
492
+ end
493
+
494
+ def load_model(force: false)
495
+ download_model_file(force: force)
496
+ load_model_file
497
+ end
498
+
499
# Assigns the list of metric names tracked for this model.
#
# Accepts a single name or an array of names (Strings or Symbols). Names
# are normalized to Strings and de-duplicated, keeping first-seen order.
#
# A nil value now yields an empty list; the original wrapped nil into
# [nil] and stored [""], i.e. a bogus empty metric name.
def metrics=(value)
  names = value.is_a?(Array) ? value : [value].compact
  @metrics = names.map(&:to_s).uniq
end
505
+
506
+ def adapter
507
+ @adapter ||= begin
508
+ adapter_class = MODEL_OPTIONS[model_type]
509
+ raise "Don't know how to use model adapter #{model_type}!" unless adapter_class.present?
510
+
511
+ adapter_class.constantize.new(self)
512
+ end
44
513
  end
45
514
 
46
515
  private
47
516
 
517
+ def default_evaluation_inputs
518
+ x_true, y_true = dataset.test(split_ys: true)
519
+ y_pred = predict(x_true)
520
+ {
521
+ x_true: x_true,
522
+ y_true: y_true,
523
+ y_pred: y_pred,
524
+ }
525
+ end
526
+
527
# File-system friendly form of the model name: runs of two or more
# whitespace characters collapse to one space, every remaining whitespace
# character becomes "_", and the result is lowercased.
#
# Falls back to the demodulized class name when name is nil, matching the
# later definition of underscored_name in this class (which overrides this
# one at load time). The original version here raised NoMethodError on a
# nil name.
def underscored_name
  base = name || self.class.name.split("::").last
  base.gsub(/\s{2,}/, " ").gsub(/\s/, "_").downcase
end
530
+
531
+ def get_model_file
532
+ model_file || new_model_file!
533
+ end
534
+
535
+ def new_model_file!
536
+ build_model_file(
537
+ root_dir: root_dir,
538
+ model: self,
539
+ s3_bucket: EasyML::Configuration.s3_bucket,
540
+ s3_region: EasyML::Configuration.s3_region,
541
+ s3_access_key_id: EasyML::Configuration.s3_access_key_id,
542
+ s3_secret_access_key: EasyML::Configuration.s3_secret_access_key,
543
+ s3_prefix: prefix,
544
+ )
545
+ end
546
+
547
+ def prefix
548
+ s3_prefix = EasyML::Configuration.s3_prefix
549
+ s3_prefix.present? ? File.join(s3_prefix, name) : name
550
+ end
551
+
552
+ def load_model!
553
+ load_model(force: true)
554
+ load_dataset
555
+ end
556
+
557
+ def load_dataset
558
+ dataset.load_dataset
559
+ end
560
+
561
+ def load_model_file
562
+ return unless model_file&.full_path && File.exist?(model_file.full_path)
563
+
564
+ adapter.load_model_file(model_file.full_path)
565
+ end
566
+
567
# Downloads the model file unless it is unnecessary: does nothing for
# unsaved records, and nothing when the model is already loaded and force
# is false.
#
# Note: loaded? is evaluated before force is consulted, and loaded? may
# itself load the local model file into the adapter.
+ def download_model_file(force: false)
568
+ return unless persisted?
569
+ return if loaded? && !force
570
+
571
+ get_model_file.download
572
+ end
573
+
48
574
  def files_to_keep
49
- live_models = self.class.live
575
+ inference_models = EasyML::ModelHistory.deployed
576
+ training_models = EasyML::Model.all
50
577
 
51
- recent_copies = live_models.flat_map do |live|
52
- # Fetch all models with the same name
53
- self.class.where(name: live.name).where(is_live: false).order(created_at: :desc).limit(live.name == name ? 4 : 5)
54
- end
578
+ ([self] + training_models + inference_models).compact.map(&:model_file).compact.map(&:full_path).uniq
579
+ end
55
580
 
56
- recent_versions = self.class
57
- .where.not(
58
- "EXISTS (SELECT 1 FROM easy_ml_models e2 WHERE e2.name = easy_ml_models.name AND e2.is_live = true)"
59
- )
60
- .where("created_at >= ?", 2.days.ago)
61
- .order(created_at: :desc)
62
- .group_by(&:name)
63
- .flat_map { |_, models| models.take(5) }
581
# File-system friendly form of the model name, falling back to the
# demodulized class name when name is nil. Runs of two or more whitespace
# characters collapse to one space, remaining whitespace becomes "_", and
# the result is lowercased.
#
# NOTE(review): underscored_name is also defined earlier in this class's
# private section; being later, this definition is the one that takes
# effect. Consider removing the earlier duplicate.
+ def underscored_name
582
+ name = self.name || self.class.name.split("::").last
583
+ name.gsub(/\s{2,}/, " ").gsub(/\s/, "_").downcase
584
+ end
64
585
 
65
- ([self] + recent_versions + recent_copies + live_models).compact.map(&:file).map(&:path).uniq
586
# after_initialize hook for new records: fills in model_type, status and
# metrics only when they are not already set.
#
# NOTE(review): the default status :training is not one of the STATUSES
# declared in this class (development, inference, retired), so none of the
# generated status predicates is true for a fresh record - confirm this is
# intended.
+ def set_defaults
587
+ self.model_type ||= "xgboost"
588
+ self.status ||= :training
589
+ self.metrics ||= default_metrics
590
+ end
591
+
592
# Validation: every configured metric must be one of allowed_metrics.
# Adds a single error on :metrics listing all unknown names.
#
# Changes from the original:
# - allowed_metrics is computed once instead of once per metric;
# - a nil metrics list is treated as empty instead of raising;
# - the noun in the message follows the number of unknown metrics (the
#   original passed the array to String#pluralize, so it always read
#   "metrics").
def validate_metrics_allowed
  permitted = allowed_metrics
  unknown_metrics = Array(metrics).reject { |metric| permitted.include?(metric) }
  return if unknown_metrics.empty?

  noun = unknown_metrics.size == 1 ? "metric" : "metrics"
  errors.add(:metrics,
             "don't know how to handle #{noun} #{unknown_metrics.join(", ")}, use EasyML::Core::ModelEvaluator.register(:name, Evaluator, :regression|:classification)")
end
67
599
  end
68
600
  end
601
+
602
+ require_relative "models/xgboost"