easy_ml 0.1.4 → 0.2.0.pre.rc1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (239) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +234 -26
  3. data/Rakefile +45 -0
  4. data/app/controllers/easy_ml/application_controller.rb +67 -0
  5. data/app/controllers/easy_ml/columns_controller.rb +38 -0
  6. data/app/controllers/easy_ml/datasets_controller.rb +156 -0
  7. data/app/controllers/easy_ml/datasources_controller.rb +88 -0
  8. data/app/controllers/easy_ml/deploys_controller.rb +20 -0
  9. data/app/controllers/easy_ml/models_controller.rb +151 -0
  10. data/app/controllers/easy_ml/retraining_runs_controller.rb +19 -0
  11. data/app/controllers/easy_ml/settings_controller.rb +59 -0
  12. data/app/frontend/components/AlertProvider.tsx +108 -0
  13. data/app/frontend/components/DatasetPreview.tsx +161 -0
  14. data/app/frontend/components/EmptyState.tsx +28 -0
  15. data/app/frontend/components/ModelCard.tsx +255 -0
  16. data/app/frontend/components/ModelDetails.tsx +334 -0
  17. data/app/frontend/components/ModelForm.tsx +384 -0
  18. data/app/frontend/components/Navigation.tsx +300 -0
  19. data/app/frontend/components/Pagination.tsx +72 -0
  20. data/app/frontend/components/Popover.tsx +55 -0
  21. data/app/frontend/components/PredictionStream.tsx +105 -0
  22. data/app/frontend/components/ScheduleModal.tsx +726 -0
  23. data/app/frontend/components/SearchInput.tsx +23 -0
  24. data/app/frontend/components/SearchableSelect.tsx +132 -0
  25. data/app/frontend/components/dataset/AutosaveIndicator.tsx +39 -0
  26. data/app/frontend/components/dataset/ColumnConfigModal.tsx +431 -0
  27. data/app/frontend/components/dataset/ColumnFilters.tsx +256 -0
  28. data/app/frontend/components/dataset/ColumnList.tsx +101 -0
  29. data/app/frontend/components/dataset/FeatureConfigPopover.tsx +57 -0
  30. data/app/frontend/components/dataset/FeaturePicker.tsx +205 -0
  31. data/app/frontend/components/dataset/PreprocessingConfig.tsx +704 -0
  32. data/app/frontend/components/dataset/SplitConfigurator.tsx +120 -0
  33. data/app/frontend/components/dataset/splitters/DateSplitter.tsx +58 -0
  34. data/app/frontend/components/dataset/splitters/KFoldSplitter.tsx +68 -0
  35. data/app/frontend/components/dataset/splitters/LeavePOutSplitter.tsx +29 -0
  36. data/app/frontend/components/dataset/splitters/PredefinedSplitter.tsx +146 -0
  37. data/app/frontend/components/dataset/splitters/RandomSplitter.tsx +85 -0
  38. data/app/frontend/components/dataset/splitters/StratifiedSplitter.tsx +79 -0
  39. data/app/frontend/components/dataset/splitters/constants.ts +77 -0
  40. data/app/frontend/components/dataset/splitters/types.ts +168 -0
  41. data/app/frontend/components/dataset/splitters/utils.ts +53 -0
  42. data/app/frontend/components/features/CodeEditor.tsx +46 -0
  43. data/app/frontend/components/features/DataPreview.tsx +150 -0
  44. data/app/frontend/components/features/FeatureCard.tsx +88 -0
  45. data/app/frontend/components/features/FeatureForm.tsx +235 -0
  46. data/app/frontend/components/features/FeatureGroupCard.tsx +54 -0
  47. data/app/frontend/components/settings/PluginSettings.tsx +81 -0
  48. data/app/frontend/components/ui/badge.tsx +44 -0
  49. data/app/frontend/components/ui/collapsible.tsx +9 -0
  50. data/app/frontend/components/ui/scroll-area.tsx +46 -0
  51. data/app/frontend/components/ui/separator.tsx +29 -0
  52. data/app/frontend/entrypoints/App.tsx +40 -0
  53. data/app/frontend/entrypoints/Application.tsx +24 -0
  54. data/app/frontend/hooks/useAutosave.ts +61 -0
  55. data/app/frontend/layouts/Layout.tsx +38 -0
  56. data/app/frontend/lib/utils.ts +6 -0
  57. data/app/frontend/mockData.ts +272 -0
  58. data/app/frontend/pages/DatasetDetailsPage.tsx +103 -0
  59. data/app/frontend/pages/DatasetsPage.tsx +261 -0
  60. data/app/frontend/pages/DatasourceFormPage.tsx +147 -0
  61. data/app/frontend/pages/DatasourcesPage.tsx +261 -0
  62. data/app/frontend/pages/EditModelPage.tsx +45 -0
  63. data/app/frontend/pages/EditTransformationPage.tsx +56 -0
  64. data/app/frontend/pages/ModelsPage.tsx +115 -0
  65. data/app/frontend/pages/NewDatasetPage.tsx +366 -0
  66. data/app/frontend/pages/NewModelPage.tsx +45 -0
  67. data/app/frontend/pages/NewTransformationPage.tsx +43 -0
  68. data/app/frontend/pages/SettingsPage.tsx +272 -0
  69. data/app/frontend/pages/ShowModelPage.tsx +30 -0
  70. data/app/frontend/pages/TransformationsPage.tsx +95 -0
  71. data/app/frontend/styles/application.css +100 -0
  72. data/app/frontend/types/dataset.ts +146 -0
  73. data/app/frontend/types/datasource.ts +33 -0
  74. data/app/frontend/types/preprocessing.ts +1 -0
  75. data/app/frontend/types.ts +113 -0
  76. data/app/helpers/easy_ml/application_helper.rb +10 -0
  77. data/app/jobs/easy_ml/application_job.rb +21 -0
  78. data/app/jobs/easy_ml/batch_job.rb +46 -0
  79. data/app/jobs/easy_ml/compute_feature_job.rb +19 -0
  80. data/app/jobs/easy_ml/deploy_job.rb +13 -0
  81. data/app/jobs/easy_ml/finalize_feature_job.rb +15 -0
  82. data/app/jobs/easy_ml/refresh_dataset_job.rb +32 -0
  83. data/app/jobs/easy_ml/schedule_retraining_job.rb +11 -0
  84. data/app/jobs/easy_ml/sync_datasource_job.rb +17 -0
  85. data/app/jobs/easy_ml/training_job.rb +62 -0
  86. data/app/models/easy_ml/adapters/base_adapter.rb +45 -0
  87. data/app/models/easy_ml/adapters/polars_adapter.rb +77 -0
  88. data/app/models/easy_ml/cleaner.rb +82 -0
  89. data/app/models/easy_ml/column.rb +124 -0
  90. data/app/models/easy_ml/column_history.rb +30 -0
  91. data/app/models/easy_ml/column_list.rb +122 -0
  92. data/app/models/easy_ml/concerns/configurable.rb +61 -0
  93. data/app/models/easy_ml/concerns/versionable.rb +19 -0
  94. data/app/models/easy_ml/dataset.rb +767 -0
  95. data/app/models/easy_ml/dataset_history.rb +56 -0
  96. data/app/models/easy_ml/datasource.rb +182 -0
  97. data/app/models/easy_ml/datasource_history.rb +24 -0
  98. data/app/models/easy_ml/datasources/base_datasource.rb +54 -0
  99. data/app/models/easy_ml/datasources/file_datasource.rb +58 -0
  100. data/app/models/easy_ml/datasources/polars_datasource.rb +89 -0
  101. data/app/models/easy_ml/datasources/s3_datasource.rb +97 -0
  102. data/app/models/easy_ml/deploy.rb +114 -0
  103. data/app/models/easy_ml/event.rb +79 -0
  104. data/app/models/easy_ml/feature.rb +437 -0
  105. data/app/models/easy_ml/feature_history.rb +38 -0
  106. data/app/models/easy_ml/model.rb +575 -41
  107. data/app/models/easy_ml/model_file.rb +133 -0
  108. data/app/models/easy_ml/model_file_history.rb +24 -0
  109. data/app/models/easy_ml/model_history.rb +51 -0
  110. data/app/models/easy_ml/models/base_model.rb +58 -0
  111. data/app/models/easy_ml/models/hyperparameters/base.rb +99 -0
  112. data/app/models/easy_ml/models/hyperparameters/xgboost/dart.rb +82 -0
  113. data/app/models/easy_ml/models/hyperparameters/xgboost/gblinear.rb +82 -0
  114. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +97 -0
  115. data/app/models/easy_ml/models/hyperparameters/xgboost.rb +71 -0
  116. data/app/models/easy_ml/models/xgboost/evals_callback.rb +138 -0
  117. data/app/models/easy_ml/models/xgboost/progress_callback.rb +39 -0
  118. data/app/models/easy_ml/models/xgboost.rb +544 -5
  119. data/app/models/easy_ml/prediction.rb +44 -0
  120. data/app/models/easy_ml/retraining_job.rb +278 -0
  121. data/app/models/easy_ml/retraining_run.rb +184 -0
  122. data/app/models/easy_ml/settings.rb +37 -0
  123. data/app/models/easy_ml/splitter.rb +90 -0
  124. data/app/models/easy_ml/splitters/base_splitter.rb +28 -0
  125. data/app/models/easy_ml/splitters/date_splitter.rb +91 -0
  126. data/app/models/easy_ml/splitters/predefined_splitter.rb +74 -0
  127. data/app/models/easy_ml/splitters/random_splitter.rb +82 -0
  128. data/app/models/easy_ml/tuner_job.rb +56 -0
  129. data/app/models/easy_ml/tuner_run.rb +31 -0
  130. data/app/models/splitter_history.rb +6 -0
  131. data/app/serializers/easy_ml/column_serializer.rb +27 -0
  132. data/app/serializers/easy_ml/dataset_serializer.rb +73 -0
  133. data/app/serializers/easy_ml/datasource_serializer.rb +64 -0
  134. data/app/serializers/easy_ml/feature_serializer.rb +27 -0
  135. data/app/serializers/easy_ml/model_serializer.rb +90 -0
  136. data/app/serializers/easy_ml/retraining_job_serializer.rb +22 -0
  137. data/app/serializers/easy_ml/retraining_run_serializer.rb +39 -0
  138. data/app/serializers/easy_ml/settings_serializer.rb +9 -0
  139. data/app/views/layouts/easy_ml/application.html.erb +15 -0
  140. data/config/initializers/resque.rb +3 -0
  141. data/config/resque-pool.yml +6 -0
  142. data/config/routes.rb +39 -0
  143. data/config/spring.rb +1 -0
  144. data/config/vite.json +15 -0
  145. data/lib/easy_ml/configuration.rb +64 -0
  146. data/lib/easy_ml/core/evaluators/base_evaluator.rb +53 -0
  147. data/lib/easy_ml/core/evaluators/classification_evaluators.rb +126 -0
  148. data/lib/easy_ml/core/evaluators/regression_evaluators.rb +66 -0
  149. data/lib/easy_ml/core/model_evaluator.rb +161 -89
  150. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +28 -18
  151. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +4 -25
  152. data/lib/easy_ml/core/tuner.rb +123 -62
  153. data/lib/easy_ml/core.rb +0 -3
  154. data/lib/easy_ml/core_ext/hash.rb +24 -0
  155. data/lib/easy_ml/core_ext/pathname.rb +11 -5
  156. data/lib/easy_ml/data/date_converter.rb +90 -0
  157. data/lib/easy_ml/data/filter_extensions.rb +31 -0
  158. data/lib/easy_ml/data/polars_column.rb +126 -0
  159. data/lib/easy_ml/data/polars_reader.rb +297 -0
  160. data/lib/easy_ml/data/preprocessor.rb +280 -142
  161. data/lib/easy_ml/data/simple_imputer.rb +255 -0
  162. data/lib/easy_ml/data/splits/file_split.rb +252 -0
  163. data/lib/easy_ml/data/splits/in_memory_split.rb +54 -0
  164. data/lib/easy_ml/data/splits/split.rb +95 -0
  165. data/lib/easy_ml/data/splits.rb +9 -0
  166. data/lib/easy_ml/data/statistics_learner.rb +93 -0
  167. data/lib/easy_ml/data/synced_directory.rb +341 -0
  168. data/lib/easy_ml/data.rb +6 -2
  169. data/lib/easy_ml/engine.rb +105 -6
  170. data/lib/easy_ml/feature_store.rb +227 -0
  171. data/lib/easy_ml/features.rb +61 -0
  172. data/lib/easy_ml/initializers/inflections.rb +17 -3
  173. data/lib/easy_ml/logging.rb +2 -2
  174. data/lib/easy_ml/predict.rb +74 -0
  175. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +192 -36
  176. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_column_histories.rb.tt +9 -0
  177. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_columns.rb.tt +25 -0
  178. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_dataset_histories.rb.tt +9 -0
  179. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasets.rb.tt +31 -0
  180. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasource_histories.rb.tt +9 -0
  181. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_datasources.rb.tt +16 -0
  182. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_deploys.rb.tt +24 -0
  183. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_events.rb.tt +20 -0
  184. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_feature_histories.rb.tt +14 -0
  185. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_features.rb.tt +32 -0
  186. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_file_histories.rb.tt +9 -0
  187. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_files.rb.tt +17 -0
  188. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_model_histories.rb.tt +9 -0
  189. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +20 -9
  190. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_predictions.rb.tt +17 -0
  191. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_retraining_jobs.rb.tt +77 -0
  192. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_settings.rb.tt +9 -0
  193. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitter_histories.rb.tt +9 -0
  194. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_splitters.rb.tt +15 -0
  195. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_tuner_jobs.rb.tt +40 -0
  196. data/lib/easy_ml/support/est.rb +5 -1
  197. data/lib/easy_ml/support/file_rotate.rb +79 -15
  198. data/lib/easy_ml/support/file_support.rb +9 -0
  199. data/lib/easy_ml/support/local_file.rb +24 -0
  200. data/lib/easy_ml/support/lockable.rb +62 -0
  201. data/lib/easy_ml/support/synced_file.rb +103 -0
  202. data/lib/easy_ml/support/utc.rb +5 -1
  203. data/lib/easy_ml/support.rb +6 -3
  204. data/lib/easy_ml/version.rb +4 -1
  205. data/lib/easy_ml.rb +7 -2
  206. metadata +355 -72
  207. data/app/models/easy_ml/models.rb +0 -5
  208. data/lib/easy_ml/core/model.rb +0 -30
  209. data/lib/easy_ml/core/model_core.rb +0 -181
  210. data/lib/easy_ml/core/models/hyperparameters/base.rb +0 -34
  211. data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +0 -19
  212. data/lib/easy_ml/core/models/xgboost.rb +0 -10
  213. data/lib/easy_ml/core/models/xgboost_core.rb +0 -220
  214. data/lib/easy_ml/core/models.rb +0 -10
  215. data/lib/easy_ml/core/uploaders/model_uploader.rb +0 -24
  216. data/lib/easy_ml/core/uploaders.rb +0 -7
  217. data/lib/easy_ml/data/dataloader.rb +0 -6
  218. data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +0 -31
  219. data/lib/easy_ml/data/dataset/data/sample_info.json +0 -1
  220. data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +0 -1
  221. data/lib/easy_ml/data/dataset/splits/file_split.rb +0 -140
  222. data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +0 -49
  223. data/lib/easy_ml/data/dataset/splits/split.rb +0 -98
  224. data/lib/easy_ml/data/dataset/splits.rb +0 -11
  225. data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +0 -43
  226. data/lib/easy_ml/data/dataset/splitters.rb +0 -9
  227. data/lib/easy_ml/data/dataset.rb +0 -430
  228. data/lib/easy_ml/data/datasource/datasource_factory.rb +0 -60
  229. data/lib/easy_ml/data/datasource/file_datasource.rb +0 -40
  230. data/lib/easy_ml/data/datasource/merged_datasource.rb +0 -64
  231. data/lib/easy_ml/data/datasource/polars_datasource.rb +0 -41
  232. data/lib/easy_ml/data/datasource/s3_datasource.rb +0 -89
  233. data/lib/easy_ml/data/datasource.rb +0 -33
  234. data/lib/easy_ml/data/preprocessor/preprocessor.rb +0 -205
  235. data/lib/easy_ml/data/preprocessor/simple_imputer.rb +0 -402
  236. data/lib/easy_ml/deployment.rb +0 -5
  237. data/lib/easy_ml/support/synced_directory.rb +0 -134
  238. data/lib/easy_ml/transforms.rb +0 -29
  239. /data/{lib/easy_ml/core → app/models/easy_ml}/models/hyperparameters.rb +0 -0
@@ -1,68 +1,602 @@
1
- require_relative "../../../lib/easy_ml/core/model"
1
+ # == Schema Information
2
+ #
3
+ # Table name: easy_ml_models
4
+ #
5
+ # id :bigint not null, primary key
6
+ # name :string not null
7
+ # model_type :string
8
+ # status :string
9
+ # dataset_id :bigint
10
+ # model_file_id :bigint
11
+ # configuration :json
12
+ # version :string not null
13
+ # root_dir :string
14
+ # file :json
15
+ # sha :string
16
+ # last_trained_at :datetime
17
+ # is_training :boolean
18
+ # created_at :datetime not null
19
+ # updated_at :datetime not null
20
+ #
21
+ require_relative "models/hyperparameters"
22
+
2
23
  module EasyML
3
24
  class Model < ActiveRecord::Base
4
- if ActiveRecord::Base.connection.data_source_exists?("easy_ml_models")
5
- include EasyML::Core::ModelCore
6
-
7
- self.table_name = "easy_ml_models"
8
- else
9
- # Placeholder if the table doesn't exist (keeps the file quiet)
10
- def self.table_ready?
11
- false
25
+ self.table_name = "easy_ml_models"
26
+ include Historiographer::Silent
27
+ historiographer_mode :snapshot_only
28
+
29
+ include EasyML::Concerns::Configurable
30
+ include EasyML::Concerns::Versionable
31
+
32
+ self.filter_attributes += [:configuration]
33
+
34
+ MODEL_OPTIONS = {
35
+ "xgboost" => "EasyML::Models::XGBoost",
36
+ }
37
+ MODEL_TYPES = [
38
+ {
39
+ value: "xgboost",
40
+ label: "XGBoost",
41
+ description: "Extreme Gradient Boosting, a scalable and accurate implementation of gradient boosting machines",
42
+ },
43
+ ].freeze
44
+ MODEL_NAMES = MODEL_OPTIONS.keys.freeze
45
+ MODEL_CONSTANTS = MODEL_OPTIONS.values.map(&:constantize)
46
+
47
+ add_configuration_attributes :task, :objective, :hyperparameters, :evaluator, :callbacks, :metrics
48
+ MODEL_CONSTANTS.flat_map(&:configuration_attributes).each do |attribute|
49
+ add_configuration_attributes attribute
50
+ end
51
+
52
+ belongs_to :dataset
53
+ belongs_to :model_file, class_name: "EasyML::ModelFile", foreign_key: "model_file_id", optional: true
54
+
55
+ has_one :retraining_job, class_name: "EasyML::RetrainingJob"
56
+ accepts_nested_attributes_for :retraining_job
57
+ has_many :retraining_runs, class_name: "EasyML::RetrainingRun"
58
+ has_many :deploys, class_name: "EasyML::Deploy"
59
+
60
+ scope :deployed, -> { EasyML::ModelHistory.deployed }
61
+
62
+ def latest_deploy
63
+ deploys.order(id: :desc).limit(1).last
64
+ end
65
+
66
+ after_initialize :bump_version, if: -> { new_record? }
67
+ after_initialize :set_defaults, if: -> { new_record? }
68
+ before_save :save_model_file, if: -> { is_fit? && !is_history_class? && model_changed? && !@skip_save_model_file }
69
+
70
+ VALID_TASKS = %i[regression classification].freeze
71
+
72
+ TASK_TYPES = [
73
+ {
74
+ value: "classification",
75
+ label: "Classification",
76
+ description: "Predict categorical outcomes or class labels",
77
+ },
78
+ {
79
+ value: "regression",
80
+ label: "Regression",
81
+ description: "Predict continuous numerical values",
82
+ },
83
+ ].freeze
84
+
85
+ validates :name, presence: true
86
+ validates :name, uniqueness: { case_sensitive: false }
87
+ validates :task, presence: true
88
+ validates :task, inclusion: {
89
+ in: VALID_TASKS.map { |t| [t, t.to_s] }.flatten,
90
+ message: "must be one of: #{VALID_TASKS.join(", ")}",
91
+ }
92
+ validates :model_type, inclusion: { in: MODEL_NAMES }
93
+ validates :dataset_id, presence: true
94
+ validate :validate_metrics_allowed
95
+ before_save :set_root_dir
96
+
97
+ delegate :prepare_data, :preprocess, to: :adapter
98
+
99
+ STATUSES = %w[development inference retired]
100
+ STATUSES.each do |status|
101
+ define_method "#{status}?" do
102
+ self.status.to_sym == status.to_sym
12
103
  end
104
+ end
13
105
 
14
- Rails.logger.info("Skipping EasyML::Model definition as the 'easy_ml_models' table doesn't exist.")
106
+ def training?
107
+ is_training == true
15
108
  end
16
109
 
17
- scope :live, -> { where(is_live: true) }
18
- attribute :root_dir, :string
19
- after_initialize :apply_defaults
110
+ def train(async: true)
111
+ pending_run # Ensure we update the pending job before enqueuing in background so UI updates properly
112
+ update(is_training: true)
113
+ if async
114
+ EasyML::TrainingJob.perform_later(id)
115
+ else
116
+ actually_train
117
+ end
118
+ end
20
119
 
21
- validate :only_one_model_is_live?
22
- def only_one_model_is_live?
23
- return if @marking_live
120
+ def get_retraining_job
121
+ if retraining_job
122
+ self.evaluator = retraining_job.evaluator
123
+ evaluator = self.evaluator.symbolize_keys
124
+ else
125
+ default_eval = Core::ModelEvaluator.default_evaluator(task)
126
+ self.evaluator = default_eval
127
+ evaluator = default_eval
128
+ end
129
+
130
+ retraining_job || create_retraining_job(
131
+ model: self,
132
+ active: false,
133
+ evaluator: evaluator,
134
+ metric: evaluator[:metric],
135
+ direction: evaluator[:direction],
136
+ threshold: evaluator[:threshold],
137
+ frequency: "month",
138
+ at: { hour: 0, day_of_month: 1 },
139
+ )
140
+ end
24
141
 
25
- if previous_versions.live.count > 1
26
- raise "Multiple previous versions of #{name} are live! This should never happen. Update previous versions to is_live=false before proceeding"
142
+ def pending_run
143
+ job = get_retraining_job
144
+ job.retraining_runs.find_or_create_by(status: "pending", model: self)
145
+ end
146
+
147
+ def actually_train(&progress_block)
148
+ lock_model do
149
+ run = pending_run
150
+ run.wrap_training do
151
+ best_params = nil
152
+ if run.should_tune?
153
+ best_params = hyperparameter_search(&progress_block)
154
+ end
155
+ fit(&progress_block)
156
+ save
157
+ [self, best_params]
158
+ end
159
+ update(is_training: false)
160
+ run.reload
161
+ ensure
162
+ unlock!
163
+ end
164
+ end
165
+
166
+ def unlock!
167
+ Support::Lockable.unlock!(lock_key)
168
+ end
169
+
170
+ def lock_model
171
+ with_lock do |client|
172
+ yield
173
+ end
174
+ end
175
+
176
+ def with_lock
177
+ EasyML::Support::Lockable.with_lock(lock_key, stale_timeout: 60, resources: 1) do |client|
178
+ yield client
179
+ end
180
+ end
181
+
182
+ def lock_key
183
+ "training:#{self.name}:#{self.id}"
184
+ end
185
+
186
+ def hyperparameter_search(&progress_block)
187
+ tuner = retraining_job.tuner_config.symbolize_keys
188
+ extra_params = {
189
+ evaluator: evaluator,
190
+ model: self,
191
+ dataset: dataset,
192
+ }.compact
193
+ tuner.merge!(extra_params)
194
+ tuner_instance = EasyML::Core::Tuner.new(tuner)
195
+ tuner_instance.tune(&progress_block).tap do |best_params|
196
+ best_params.each do |key, value|
197
+ self.hyperparameters.send("#{key}=", value)
198
+ end
27
199
  end
200
+ end
201
+
202
+ def deployment_status
203
+ status
204
+ end
205
+
206
+ def formatted_model_type
207
+ adapter.class.name.split("::").last
208
+ end
209
+
210
+ def formatted_version
211
+ return nil unless version
212
+ Time.strptime(version, "%Y%m%d%H%M%S").strftime("%B %-d, %Y at %-l:%M %p")
213
+ end
28
214
 
29
- return unless previous_versions.live.any? && is_live
215
+ def last_run_at
216
+ last_run&.created_at
217
+ end
218
+
219
+ def last_run
220
+ retraining_runs.order(id: :desc).limit(1).last
221
+ end
222
+
223
+ def inference_version
224
+ latest_deploy&.model_version
225
+ end
226
+
227
+ alias_method :current_version, :inference_version
228
+ alias_method :latest_version, :inference_version
229
+ alias_method :deployed, :inference_version
30
230
 
31
- errors.add(:is_live,
32
- "cannot mark model live when previous version is live. Explicitly use the mark_live method to mark this as the live version")
231
+ def hyperparameters
232
+ @hypers ||= adapter.build_hyperparameters(@hyperparameters)
33
233
  end
34
234
 
35
- def mark_live
36
- transaction do
37
- self.class.where(name: name).where.not(id: id).update_all(is_live: false)
38
- self.class.where(id: id).update_all(is_live: true)
235
+ def callbacks
236
+ @cbs ||= adapter.build_callbacks(@callbacks)
237
+ end
238
+
239
+ def predict(xs)
240
+ load_model!
241
+ adapter.predict(xs)
242
+ end
243
+
244
+ def save_model_file
245
+ raise "No trained model! Need to train model before saving (call model.fit)" unless is_fit?
246
+ return unless adapter.loaded?
247
+
248
+ model_file = get_model_file
249
+
250
+ bump_version(force: true)
251
+ path = model_file.full_path(version)
252
+ full_path = adapter.save_model_file(path)
253
+ model_file.upload(full_path)
254
+
255
+ model_file.save
256
+ self.model_file = model_file
257
+ cleanup
258
+ end
259
+
260
+ def feature_names
261
+ adapter.feature_names
262
+ end
263
+
264
+ def cleanup!
265
+ get_model_file&.cleanup!
266
+ end
267
+
268
+ def cleanup
269
+ get_model_file&.cleanup(files_to_keep)
270
+ end
271
+
272
+ def loaded?
273
+ model_file = get_model_file
274
+ return false if model_file.persisted? && !File.exist?(model_file.full_path.to_s)
275
+
276
+ file_exists = true
277
+ if model_file.present? && model_file.persisted? && model_file.full_path.present?
278
+ file_exists = File.exist?(model_file.full_path)
39
279
  end
280
+
281
+ loaded = adapter.loaded?
282
+ load_model_file unless loaded
283
+ file_exists && adapter.loaded?
40
284
  end
41
285
 
42
- def previous_versions
43
- EasyML::Model.where(name: name).order(id: :desc)
286
+ def model_changed?
287
+ return false unless is_fit?
288
+ return true if inference_version.nil?
289
+ return true if model_file.present? && !model_file.persisted?
290
+ return true if model_file.present? && model_file.fit? && inference_version.nil?
291
+
292
+ adapter.model_changed?(inference_version.sha)
293
+ end
294
+
295
+ def feature_importances
296
+ adapter.feature_importances
297
+ end
298
+
299
+ def fit_in_batches?
300
+ retraining_job.present? && retraining_job.batch_mode == true
301
+ end
302
+
303
+ def fit(tuning: false, x_train: nil, y_train: nil, x_valid: nil, y_valid: nil, &progress_block)
304
+ return fit_in_batches(**batch_args.merge!(tuning: tuning), &progress_block) if fit_in_batches?
305
+
306
+ dataset.refresh
307
+ adapter.fit(tuning: tuning, x_train: x_train, y_train: y_train, x_valid: x_valid, y_valid: y_valid, &progress_block)
308
+ @is_fit = true
309
+ end
310
+
311
+ def batch_args
312
+ defaults = {
313
+ batch_size: 1024,
314
+ batch_overlap: 3,
315
+ batch_key: nil,
316
+ }
317
+ overrides = { batch_size: retraining_job&.batch_size, batch_overlap: retraining_job&.batch_overlap, batch_key: retraining_job&.batch_key }.compact
318
+ defaults.merge!(overrides)
319
+ end
320
+
321
+ def batch_mode
322
+ retraining_job&.batch_mode || false
323
+ end
324
+
325
+ def prepare_callbacks(tune_started_at)
326
+ adapter.prepare_callbacks(tune_started_at)
327
+ end
328
+
329
+ def after_tuning
330
+ adapter.after_tuning
331
+ end
332
+
333
+ def fit_in_batches(tuning: false, batch_size: nil, batch_overlap: nil, batch_key: nil, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"), &progress_block)
334
+ adapter.fit_in_batches(tuning: tuning, batch_size: batch_size, batch_overlap: batch_overlap, batch_key: batch_key, checkpoint_dir: checkpoint_dir, &progress_block)
335
+ @is_fit = true
336
+ end
337
+
338
+ attr_accessor :is_fit
339
+
340
+ def is_fit?
341
+ model_file = get_model_file
342
+ return true if model_file.present? && model_file.fit?
343
+
344
+ adapter.is_fit?
345
+ end
346
+
347
+ def deployable?
348
+ cannot_deploy_reasons.none?
349
+ end
350
+
351
+ def decode_labels(ys, col: nil)
352
+ dataset.decode_labels(ys, col: col)
353
+ end
354
+
355
+ def evaluate(y_pred: nil, y_true: nil, x_true: nil, evaluator: nil)
356
+ evaluator ||= self.evaluator
357
+ if y_pred.nil?
358
+ inputs = default_evaluation_inputs
359
+ y_pred = inputs[:y_pred]
360
+ y_true = inputs[:y_true]
361
+ x_true = inputs[:x_true]
362
+ end
363
+ EasyML::Core::ModelEvaluator.evaluate(model: self, y_pred: y_pred, y_true: y_true, x_true: x_true, evaluator: evaluator)
364
+ end
365
+
366
+ def evaluator
367
+ instance_variable_get(:@evaluator) || default_evaluator
368
+ end
369
+
370
+ def default_evaluator
371
+ return nil unless task.present?
372
+
373
+ EasyML::Core::ModelEvaluator.default_evaluator(task)
374
+ end
375
+
376
+ def get_params
377
+ @hyperparameters.to_h
378
+ end
379
+
380
+ def evals
381
+ last_run&.metrics || {}
382
+ end
383
+
384
+ def metric_accessor(metric)
385
+ metrics = last_run.metrics.symbolize_keys
386
+ metrics.dig(metric.to_sym)
387
+ end
388
+
389
+ EasyML::Core::ModelEvaluator.metrics.each do |metric_name|
390
+ define_method metric_name do
391
+ metric_accessor(metric_name)
392
+ end
393
+ end
394
+
395
+ EasyML::Core::ModelEvaluator.callbacks = lambda do |metric_name|
396
+ EasyML::Model.define_method metric_name do
397
+ metric_accessor(metric_name)
398
+ end
399
+ end
400
+
401
+ def allowed_metrics
402
+ EasyML::Core::ModelEvaluator.metrics(task).map(&:to_s)
403
+ end
404
+
405
+ def default_metrics
406
+ return [] unless task.present?
407
+
408
+ case task.to_sym
409
+ when :regression
410
+ %w[mean_absolute_error mean_squared_error root_mean_squared_error r2_score]
411
+ when :classification
412
+ %w[accuracy_score precision_score recall_score f1_score]
413
+ else
414
+ []
415
+ end
416
+ end
417
+
418
+ def self.constants
419
+ {
420
+ objectives: objectives_by_model_type,
421
+ metrics: metrics_by_task,
422
+ tasks: TASK_TYPES,
423
+ timezone: EasyML::Configuration.timezone_label,
424
+ retraining_job_constants: EasyML::RetrainingJob.constants,
425
+ tuner_job_constants: EasyML::TunerJob.constants,
426
+ }
427
+ end
428
+
429
+ def self.metrics_by_task
430
+ EasyML::Core::ModelEvaluator.metrics_by_task
431
+ end
432
+
433
+ def self.objectives_by_model_type
434
+ MODEL_OPTIONS.inject({}) do |h, (k, v)|
435
+ h.tap do
436
+ h[k] = v.constantize.const_get(:OBJECTIVES_FRONTEND)
437
+ end
438
+ end.deep_symbolize_keys
439
+ end
440
+
441
+ def attributes
442
+ super.merge!(
443
+ hyperparameters: hyperparameters.to_h,
444
+ )
445
+ end
446
+
447
+ class CannotdeployError < StandardError
448
+ end
449
+
450
+ def deploy(async: true)
451
+ last_run.deploy(async: async)
452
+ end
453
+
454
+ def actually_deploy
455
+ raise CannotdeployError, cannot_deploy_reasons.first if cannot_deploy_reasons.any?
456
+
457
+ # Prepare the inference model by freezing + saving the model, dataset, and datasource
458
+ # (This creates ModelHistory, DatasetHistory, etc)
459
+ save_model_file
460
+ self.sha = model_file.sha
461
+ save
462
+ dataset.upload_remote_files
463
+ snapshot.tap do
464
+ # Prepare the model to be retrained (reset values so they don't conflict with our snapshotted version)
465
+ bump_version(force: true)
466
+ dataset.bump_versions(version)
467
+ self.model_file = new_model_file!
468
+ save
469
+ end
470
+ end
471
+
472
+ def cannot_deploy_reasons
473
+ [
474
+ is_fit? ? nil : "Model has not been trained",
475
+ dataset.target.present? ? nil : "Dataset has no target",
476
+ !dataset.datasource.in_memory? ? nil : "Cannot perform inference using an in-memory datasource",
477
+ ].compact
478
+ end
479
+
480
+ def root_dir=(value)
481
+ raise "Cannot override value of root_dir!" unless value.to_s == root_dir.to_s
482
+
483
+ write_attribute(:root_dir, value)
484
+ end
485
+
486
+ def set_root_dir
487
+ write_attribute(:root_dir, root_dir)
488
+ end
489
+
490
+ def root_dir
491
+ EasyML::Engine.root_dir.join("models").join(underscored_name).to_s
492
+ end
493
+
494
+ def load_model(force: false)
495
+ download_model_file(force: force)
496
+ load_model_file
497
+ end
498
+
499
+ def metrics=(value)
500
+ value = [value] unless value.is_a?(Array)
501
+ value = value.map(&:to_s)
502
+ value = value.uniq
503
+ @metrics = value
504
+ end
505
+
506
+ def adapter
507
+ @adapter ||= begin
508
+ adapter_class = MODEL_OPTIONS[model_type]
509
+ raise "Don't know how to use model adapter #{model_type}!" unless adapter_class.present?
510
+
511
+ adapter_class.constantize.new(self)
512
+ end
44
513
  end
45
514
 
46
515
  private
47
516
 
517
+ def default_evaluation_inputs
518
+ x_true, y_true = dataset.test(split_ys: true)
519
+ y_pred = predict(x_true)
520
+ {
521
+ x_true: x_true,
522
+ y_true: y_true,
523
+ y_pred: y_pred,
524
+ }
525
+ end
526
+
527
+ def underscored_name
528
+ name.gsub(/\s{2,}/, " ").gsub(/\s/, "_").downcase
529
+ end
530
+
531
+ def get_model_file
532
+ model_file || new_model_file!
533
+ end
534
+
535
+ def new_model_file!
536
+ build_model_file(
537
+ root_dir: root_dir,
538
+ model: self,
539
+ s3_bucket: EasyML::Configuration.s3_bucket,
540
+ s3_region: EasyML::Configuration.s3_region,
541
+ s3_access_key_id: EasyML::Configuration.s3_access_key_id,
542
+ s3_secret_access_key: EasyML::Configuration.s3_secret_access_key,
543
+ s3_prefix: prefix,
544
+ )
545
+ end
546
+
547
+ def prefix
548
+ s3_prefix = EasyML::Configuration.s3_prefix
549
+ s3_prefix.present? ? File.join(s3_prefix, name) : name
550
+ end
551
+
552
+ def load_model!
553
+ load_model(force: true)
554
+ load_dataset
555
+ end
556
+
557
+ def load_dataset
558
+ dataset.load_dataset
559
+ end
560
+
561
+ def load_model_file
562
+ return unless model_file&.full_path && File.exist?(model_file.full_path)
563
+
564
+ adapter.load_model_file(model_file.full_path)
565
+ end
566
+
567
+ def download_model_file(force: false)
568
+ return unless persisted?
569
+ return if loaded? && !force
570
+
571
+ get_model_file.download
572
+ end
573
+
48
574
  def files_to_keep
49
- live_models = self.class.live
575
+ inference_models = EasyML::ModelHistory.deployed
576
+ training_models = EasyML::Model.all
50
577
 
51
- recent_copies = live_models.flat_map do |live|
52
- # Fetch all models with the same name
53
- self.class.where(name: live.name).where(is_live: false).order(created_at: :desc).limit(live.name == name ? 4 : 5)
54
- end
578
+ ([self] + training_models + inference_models).compact.map(&:model_file).compact.map(&:full_path).uniq
579
+ end
55
580
 
56
- recent_versions = self.class
57
- .where.not(
58
- "EXISTS (SELECT 1 FROM easy_ml_models e2 WHERE e2.name = easy_ml_models.name AND e2.is_live = true)"
59
- )
60
- .where("created_at >= ?", 2.days.ago)
61
- .order(created_at: :desc)
62
- .group_by(&:name)
63
- .flat_map { |_, models| models.take(5) }
581
+ def underscored_name
582
+ name = self.name || self.class.name.split("::").last
583
+ name.gsub(/\s{2,}/, " ").gsub(/\s/, "_").downcase
584
+ end
64
585
 
65
- ([self] + recent_versions + recent_copies + live_models).compact.map(&:file).map(&:path).uniq
586
+ def set_defaults
587
+ self.model_type ||= "xgboost"
588
+ self.status ||= :training
589
+ self.metrics ||= default_metrics
590
+ end
591
+
592
+ def validate_metrics_allowed
593
+ unknown_metrics = metrics.select { |metric| allowed_metrics.exclude?(metric) }
594
+ return unless unknown_metrics.any?
595
+
596
+ errors.add(:metrics,
597
+ "don't know how to handle #{"metrics".pluralize(unknown_metrics)} #{unknown_metrics.join(", ")}, use EasyML::Core::ModelEvaluator.register(:name, Evaluator, :regression|:classification)")
66
598
  end
67
599
  end
68
600
  end
601
+
602
+ require_relative "models/xgboost"