easy_ml 0.2.0.pre.rc77 → 0.2.0.pre.rc81

This diff shows the content of publicly available package versions as released to their public registries, and is provided for informational purposes only.
Files changed (55)
  1. checksums.yaml +4 -4
  2. data/app/controllers/easy_ml/datasets_controller.rb +3 -3
  3. data/app/controllers/easy_ml/models_controller.rb +4 -3
  4. data/app/frontend/components/ModelForm.tsx +16 -0
  5. data/app/frontend/components/ScheduleModal.tsx +0 -2
  6. data/app/frontend/components/dataset/PreprocessingConfig.tsx +7 -6
  7. data/app/jobs/easy_ml/application_job.rb +1 -0
  8. data/app/jobs/easy_ml/batch_job.rb +47 -6
  9. data/app/jobs/easy_ml/compute_feature_job.rb +10 -10
  10. data/app/jobs/easy_ml/reaper.rb +14 -10
  11. data/app/jobs/easy_ml/refresh_dataset_job.rb +2 -0
  12. data/app/jobs/easy_ml/sync_datasource_job.rb +1 -0
  13. data/app/models/concerns/easy_ml/dataframe_serialization.rb +1 -17
  14. data/app/models/easy_ml/column/imputers/base.rb +1 -1
  15. data/app/models/easy_ml/column/imputers/imputer.rb +2 -0
  16. data/app/models/easy_ml/column/imputers/today.rb +1 -1
  17. data/app/models/easy_ml/column/selector.rb +0 -8
  18. data/app/models/easy_ml/column.rb +1 -1
  19. data/app/models/easy_ml/column_list.rb +2 -3
  20. data/app/models/easy_ml/dataset/learner/base.rb +2 -2
  21. data/app/models/easy_ml/dataset/learner/eager.rb +3 -1
  22. data/app/models/easy_ml/dataset/learner/lazy.rb +4 -1
  23. data/app/models/easy_ml/dataset.rb +47 -38
  24. data/app/models/easy_ml/datasource.rb +0 -6
  25. data/app/models/easy_ml/feature.rb +33 -8
  26. data/app/models/easy_ml/model.rb +27 -4
  27. data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb +21 -5
  28. data/app/models/easy_ml/models/xgboost/evals_callback.rb +9 -5
  29. data/app/models/easy_ml/models/xgboost.rb +58 -36
  30. data/app/models/easy_ml/retraining_run.rb +1 -1
  31. data/app/serializers/easy_ml/model_serializer.rb +1 -0
  32. data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +16 -3
  33. data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +0 -17
  34. data/lib/easy_ml/core/tuner.rb +14 -5
  35. data/lib/easy_ml/data/dataset_manager/reader/base.rb +12 -0
  36. data/lib/easy_ml/data/dataset_manager/reader/data_frame.rb +8 -3
  37. data/lib/easy_ml/data/dataset_manager/reader/file.rb +5 -0
  38. data/lib/easy_ml/data/dataset_manager/reader.rb +7 -1
  39. data/lib/easy_ml/data/dataset_manager/writer/base.rb +26 -9
  40. data/lib/easy_ml/data/dataset_manager/writer.rb +5 -1
  41. data/lib/easy_ml/data/dataset_manager.rb +18 -4
  42. data/lib/easy_ml/data/embeddings/adapters.rb +56 -0
  43. data/lib/easy_ml/data/embeddings/compression.rb +0 -0
  44. data/lib/easy_ml/data/embeddings.rb +43 -0
  45. data/lib/easy_ml/data/polars_column.rb +19 -5
  46. data/lib/easy_ml/engine.rb +16 -14
  47. data/lib/easy_ml/feature_store.rb +19 -16
  48. data/lib/easy_ml/support/lockable.rb +1 -5
  49. data/lib/easy_ml/version.rb +1 -1
  50. data/public/easy_ml/assets/.vite/manifest.json +1 -1
  51. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-Bbf3mD_b.js +522 -0
  52. data/public/easy_ml/assets/assets/entrypoints/{Application.tsx-B1qLZuyu.js.map → Application.tsx-Bbf3mD_b.js.map} +1 -1
  53. metadata +9 -7
  54. data/app/models/easy_ml/datasources/polars_datasource.rb +0 -69
  55. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-B1qLZuyu.js +0 -522
data/app/models/easy_ml/dataset.rb
@@ -180,6 +180,8 @@ module EasyML
    EasyML::Reaper.kill(EasyML::RefreshDatasetJob, id)
    update(workflow_status: :ready)
    unlock!
+   features.update_all(needs_fit: true, workflow_status: "ready")
+   features.each(&:wipe)
  end

  def refresh_async
@@ -201,12 +203,6 @@ module EasyML
    @raw = initialize_split("raw")
  end

- def clipped
-   return @clipped if @clipped && @clipped.dataset
-
-   @clipped = initialize_split("clipped")
- end
-
  def processed
    return @processed if @processed && @processed.dataset

@@ -236,20 +232,20 @@ module EasyML
    cleanup
    refresh_datasource!
    split_data
-   process_data
+   fit
  end

  def prepare
    prepare_features
    refresh_datasource
    split_data
-   process_data
+   fit
  end

  def actually_refresh
    refreshing do
-     learn(delete: false) # After syncing datasource, learn new statistics + sync columns
-     process_data
+     fit
+     normalize_all
      fully_reload
      learn
      learn_statistics(type: :processed) # After processing data, we learn any new statistics
@@ -287,6 +283,7 @@ module EasyML

  def fit_features(async: false, features: self.features, force: false)
    features_to_compute = force ? features : features.needs_fit
+   puts "Features to compute.... #{features_to_compute}"
    return after_fit_features if features_to_compute.empty?

    features.first.fit(features: features_to_compute, async: async)
@@ -295,10 +292,12 @@ module EasyML
  measure_method_timing :fit_features

  def after_fit_features
+   puts "After fit features"
    unlock!
    reload
    return if failed?

+   puts "Actually refresh..."
    actually_refresh
  end

@@ -385,6 +384,8 @@ module EasyML

  def unlock!
    Support::Lockable.unlock!(lock_key)
+   features.each(&:unlock!)
+   true
  end

  def locked?
@@ -427,12 +428,6 @@ module EasyML
    (read_attribute(:statistics) || {}).with_indifferent_access
  end

- def process_data
-   learn(delete: false)
-   fit
-   normalize_all
- end
-
  def needs_learn?
    return true if columns_need_refresh?

@@ -483,13 +478,31 @@ module EasyML
    df = apply_missing_columns(df, inference: inference)
    df = columns.transform(df, inference: inference)
    df = apply_features(df, features)
-   df = columns.transform(df, inference: inference, computed: true)
+   df = columns.transform(df, inference: inference)
    df = apply_column_mask(df, inference: inference) unless all_columns
    df = drop_nulls(df) unless inference
    df, = processed.split_features_targets(df, true, target) if split_ys
    df
  end

+ # Massage out one-hot cats to their canonical name
+ #
+ # Takes: ["Sex_male", "Sex_female", "Embarked_c", "PassengerId"]
+ # Returns: ["Embarked", "Sex", "PassengerId"]
+ def regular_columns(col_list)
+   one_hot_cats = columns.allowed_categories.invert.reduce({}) do |h, (k, v)|
+     h.tap do
+       k.each do |k2|
+         h["#{v}_#{k2}"] = v
+       end
+     end
+   end
+
+   col_list.map do |col|
+     one_hot_cats.key?(col) ? one_hot_cats[col] : col
+   end.uniq.sort
+ end
+
  measure_method_timing :normalize

  def missing_required_fields(df)
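
To make the new helper concrete, here is a small sketch of the lookup regular_columns builds and how it collapses one-hot names (the allowed_categories hash below is hypothetical):

    # Suppose columns.allowed_categories returns:
    allowed = { "Sex" => ["male", "female"], "Embarked" => ["c"] }

    # invert + reduce yields a reverse lookup from one-hot name to source column:
    one_hot_cats = allowed.invert.reduce({}) do |h, (cats, col)|
      cats.each { |cat| h["#{col}_#{cat}"] = col }
      h
    end
    # => { "Sex_male" => "Sex", "Sex_female" => "Sex", "Embarked_c" => "Embarked" }

    ["Sex_male", "Sex_female", "Embarked_c", "PassengerId"]
      .map { |c| one_hot_cats.fetch(c, c) }.uniq.sort
    # => ["Embarked", "PassengerId", "Sex"]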
@@ -537,7 +550,6 @@ module EasyML

  def cleanup
    raw.cleanup
-   clipped.cleanup
    processed.cleanup
  end

@@ -705,6 +717,20 @@ module EasyML
    reload
  end

+ def list_nulls(input = nil, list_raw = false)
+   input = data(lazy: true) if input.nil?
+
+   case input
+   when Polars::DataFrame
+     input = input.lazy
+   when String, Symbol
+     input = input.to_sym
+     input = send(input).data(lazy: true)
+   end
+   col_list = EasyML::Data::DatasetManager.list_nulls(input)
+   list_raw ? col_list : regular_columns(col_list)
+ end
+
  private

  def apply_date_splitter_config
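
Usage sketch for the new list_nulls entry point (receiver and frames hypothetical; argument handling follows the method above):

    dataset.list_nulls                # scans data(lazy: true), returns canonical names
    dataset.list_nulls(:processed)    # a String/Symbol resolves a split via send
    dataset.list_nulls(df)            # a Polars::DataFrame is promoted to lazy first
    dataset.list_nulls(df, true)      # list_raw = true keeps raw one-hot column names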
@@ -730,10 +756,8 @@ module EasyML

  def initialize_splits
    @raw = nil
-   @clipped = nil
    @processed = nil
    raw
-   clipped
    processed
  end

@@ -778,11 +802,12 @@ module EasyML
    processed.cleanup

    SPLIT_ORDER.each do |segment|
-     df = clipped.read(segment)
+     df = raw.read(segment)
      learn_computed_columns(df) if segment == :train
      processed_df = normalize(df, all_columns: true)
      processed.save(segment, processed_df)
    end
+   features.select { |f| !f.fittable? }.each(&:after_transform)
    @normalized = true
  end

@@ -825,26 +850,10 @@ module EasyML
  end

  def fit
-   apply_clip
+   learn(delete: false)
    learn_statistics(type: :raw)
  end

- def apply_clip
-   clipped.cleanup
-
-   SPLIT_ORDER.each do |segment|
-     df = raw.send(segment, lazy: true, all_columns: true)
-     clipped.save(
-       segment,
-       columns.apply_clip(df) # Ensuring this returns a LazyFrame means we'll automatically use sink_parquet
-     )
-   end
- end
-
- measure_method_timing :apply_clip
-
- # log_method :fit, "Learning statistics", verbose: true
-
  def split_data!
    split_data(force: true)
  end
data/app/models/easy_ml/datasource.rb
@@ -22,7 +22,6 @@ module EasyML
  DATASOURCE_OPTIONS = {
    "s3" => "EasyML::Datasources::S3Datasource",
    "file" => "EasyML::Datasources::FileDatasource",
-   "polars" => "EasyML::Datasources::PolarsDatasource",
  }
  DATASOURCE_TYPES = [
    {
@@ -35,11 +34,6 @@ module EasyML
      label: "Local Files",
      description: "Connect to data stored in local files",
    },
-   {
-     value: "polars",
-     label: "Polars DataFrame",
-     description: "In-memory dataframe storage using Polars",
-   },
  ].freeze
  DATASOURCE_NAMES = DATASOURCE_OPTIONS.keys.freeze
  DATASOURCE_CONSTANTS = DATASOURCE_OPTIONS.values.map(&:constantize)
data/app/models/easy_ml/feature.rb
@@ -78,16 +78,24 @@ module EasyML
  scope :never_applied, -> { where(applied_at: nil) }
  scope :never_fit, -> do
    fittable = where(fit_at: nil)
-   fittable = fittable.select { |f| f.adapter.respond_to?(:fit) }
+   fittable = fittable.select(&:fittable?)
    where(id: fittable.map(&:id))
  end
  scope :needs_fit, -> { has_changes.or(never_applied).or(never_fit) }
- scope :ready_to_apply, -> { where(needs_fit: false).where.not(id: has_changes.map(&:id)) }
+ scope :ready_to_apply, -> do
+   base = where(needs_fit: false).where.not(id: has_changes.map(&:id))
+   doesnt_fit = where_no_fit
+   where(id: base.map(&:id).concat(doesnt_fit.map(&:id)))
+ end
+
+ scope :fittable, -> { all.select(&:fittable?) }
+ scope :where_no_fit, -> { all.reject(&:fittable?) }

  before_save :apply_defaults, if: :new_record?
  before_save :update_sha
  after_find :update_from_feature_class
  before_save :update_from_feature_class
+ before_destroy :wipe

  def feature_klass
    feature_class.constantize
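
The reworked ready_to_apply scope fixes an asymmetry: transform-only features (whose adapters define no fit method) previously had to satisfy needs_fit: false before they counted as ready. It now unions two id sets, roughly as follows (a sketch, not the literal SQL):

    base = EasyML::Feature.where(needs_fit: false)
                          .where.not(id: EasyML::Feature.has_changes.map(&:id))
    no_fit = EasyML::Feature.all.reject(&:fittable?)   # transform-only features
    EasyML::Feature.where(id: base.map(&:id) + no_fit.map(&:id))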
@@ -99,6 +107,10 @@ module EasyML
    feature_klass.present?
  end

+ def fittable?
+   adapter.respond_to?(:fit)
+ end
+
  def adapter
    @adapter ||= feature_klass.new
  end
@@ -197,7 +209,7 @@ module EasyML
  end

  EasyML::Data::Partition::Boundaries.new(
-   reader.data(lazy: true),
+   reader.data(lazy: true, all_columns: true),
    primary_key,
    batch_size
  ).to_a.map.with_index do |partition, idx|
@@ -207,18 +219,23 @@ module EasyML
      batch_end: partition[:partition_end],
      batch_number: feature_position,
      subbatch_number: idx,
-     parent_batch_id: Random.uuid,
    }
  end
  end

  def wipe
+   update(needs_fit: true) if fittable?
    feature_store.wipe
  end

  def fit(features: [self], async: false)
    ordered_features = features.sort_by(&:feature_position)
-   jobs = ordered_features.map(&:build_batches)
+   parent_batch_id = Random.uuid
+   jobs = ordered_features.select(&:fittable?).map do |feature|
+     feature.build_batches.map do |batch_args|
+       batch_args.merge(parent_batch_id: parent_batch_id)
+     end
+   end
    job_count = jobs.dup.flatten.size

    ordered_features.each(&:wipe)
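
Note the moved Random.uuid: every sub-batch produced by one fit call now shares a single parent_batch_id, instead of each feature minting its own, so batch bookkeeping can track the whole run. A hypothetical illustration of what that enables downstream (batch hashes and statuses invented):

    batches.group_by { |b| b[:parent_batch_id] }.each do |parent_id, group|
      done = group.count { |b| b[:status] == "done" }
      # one group per fit run, so completion fires once per run, not per feature
      puts "#{parent_id}: #{done}/#{group.size} sub-batches complete"
    end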
@@ -445,7 +462,7 @@ module EasyML
  def after_fit
    update_sha

-   feature_store.compact
+   feature_store.compact if fittable?
    updates = {
      fit_at: Time.current,
      needs_fit: false,
@@ -454,6 +471,14 @@ module EasyML
    update!(updates)
  end

+ def after_transform
+   feature_store.compact if !fittable?
+ end
+
+ def unlock!
+   feature_store.unlock!
+ end
+
  UNCONFIGURABLE_COLUMNS = %w(
    id
    dataset_id
@@ -508,14 +533,14 @@ module EasyML
    new_sha = compute_sha
    if new_sha != self.sha
      self.sha = new_sha
-     self.needs_fit = true
+     self.needs_fit = fittable?
    end
  end

  def update_from_feature_class
    if read_attribute(:batch_size) != config.dig(:batch_size)
      write_attribute(:batch_size, config.dig(:batch_size))
-     self.needs_fit = true
+     self.needs_fit = fittable?
    end

    if self.primary_key != config.dig(:primary_key)
data/app/models/easy_ml/model.rb
@@ -45,7 +45,7 @@ module EasyML
  MODEL_NAMES = MODEL_OPTIONS.keys.freeze
  MODEL_CONSTANTS = MODEL_OPTIONS.values.map(&:constantize)

- add_configuration_attributes :task, :objective, :hyperparameters, :callbacks, :metrics
+ add_configuration_attributes :task, :objective, :hyperparameters, :callbacks, :metrics, :weights_column
  MODEL_CONSTANTS.flat_map(&:configuration_attributes).each do |attribute|
    add_configuration_attributes attribute
  end
@@ -182,12 +182,15 @@ module EasyML
  lock_model do
    run = pending_run
    run.wrap_training do
+     raise untrainable_error unless trainable?
+
      best_params = nil
      if run.should_tune?
        best_params = hyperparameter_search(&progress_block)
+     else
+       fit(&progress_block)
+       save
      end
-     fit(&progress_block)
-     save
      [self, best_params]
    end
    update(is_training: false)
@@ -258,7 +261,7 @@ module EasyML

  def formatted_version
    return nil unless version
-   Time.strptime(version, "%Y%m%d%H%M%S").strftime("%B %-d, %Y at %-l:%M %p")
+   UTC.parse(version).in_time_zone(EasyML::Configuration.timezone).strftime("%B %-d, %Y at %-l:%M %p")
  end

  def last_run_at
@@ -277,6 +280,22 @@ module EasyML
  alias_method :latest_version, :inference_version
  alias_method :deployed, :inference_version

+ def trainable?
+   adapter.trainable?
+ end
+
+ def untrainable_columns
+   adapter.untrainable_columns
+ end
+
+ def untrainable_error
+   %Q(
+     Cannot train dataset containing null values!
+     Apply preprocessing to the following columns:
+     #{untrainable_columns.join(", ")}
+   )
+ end
+
  def predict(xs)
    load_model!
    unless xs.is_a?(XGBoost::DMatrix)
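
Together with the train hunk above, these readers make null-free processed data a hard precondition for training. A sketch of the resulting behavior (receiver and column names hypothetical):

    model.trainable?           # => false while processed data still contains nulls
    model.untrainable_columns  # => ["Age", "Cabin"], columns still needing preprocessing
    # wrap_training then raises untrainable_error naming those columns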
@@ -375,6 +394,10 @@ module EasyML
    adapter.after_tuning
  end

+ def cleanup
+   adapter.cleanup
+ end
+
  def fit_in_batches(tuning: false, batch_size: nil, batch_overlap: nil, batch_key: nil, checkpoint_dir: Rails.root.join("tmp", "xgboost_checkpoints"), &progress_block)
    adapter.fit_in_batches(tuning: tuning, batch_size: batch_size, batch_overlap: batch_overlap, batch_key: batch_key, checkpoint_dir: checkpoint_dir, &progress_block)
  end
data/app/models/easy_ml/models/hyperparameters/xgboost/gbtree.rb
@@ -37,6 +37,20 @@ module EasyML
    max: 10,
    step: 0.1,
  },
+ scale_pos_weight: {
+   label: "Scale Pos Weight",
+   description: "Balance of positive and negative weights",
+   min: 0,
+   max: 200,
+   step: 1,
+ },
+ max_delta_step: {
+   label: "Max Delta Step",
+   description: "Maximum delta step",
+   min: 0,
+   max: 10,
+   step: 1,
+ },
  gamma: {
    label: "Gamma",
    description: "Minimum loss reduction required to make a further partition",
@@ -81,11 +95,13 @@ module EasyML
    label: "Histogram",
    description: "Fast histogram optimized approximate greedy algorithm",
  },
- {
-   value: "gpu_hist",
-   label: "GPU Histogram",
-   description: "GPU implementation of hist algorithm",
- },
+ # Only when compiled wih GPU support...
+ # How to make this not a default optoin
+ # {
+ #   value: "gpu_hist",
+ #   label: "GPU Histogram",
+ #   description: "GPU implementation of hist algorithm",
+ # },
  ],
  },
  )
data/app/models/easy_ml/models/xgboost/evals_callback.rb
@@ -36,7 +36,7 @@ module EasyML
    if tuner.present?
      [tuner.x_valid, tuner.y_valid]
    else
-     model.dataset.valid(split_ys: true)
+     model.dataset.valid(split_ys: true, lazy: true)
    end
  end

@@ -47,7 +47,8 @@ module EasyML
    if epoch % log_frequency == 0
      model.adapter.external_model = booster
      x_valid, y_valid = valid_dataset
-     @preprocessed ||= model.preprocess(x_valid)
+     x_valid = x_valid.select(model.dataset.col_order(inference: true))
+     @preprocessed ||= model.preprocess(x_valid, y_valid)
      y_pred = model.predict(@preprocessed)
      dataset = model.dataset.valid(all_columns: true)

@@ -102,7 +103,7 @@ module EasyML
    model.callbacks.detect { |cb| cb.class == Wandb::XGBoostCallback }
  end

- def track_cumulative_feature_importance(finish = true)
+ def track_cumulative_feature_importance
    return unless @feature_importances

    project_name = model.adapter.get_wandb_project
@@ -126,13 +127,16 @@ module EasyML
      "feature_importance" => bar_plot.__pyptr__,
    }
    Wandb.log(log_data)
-   model.adapter.delete_wandb_project if finish
-   Wandb.finish if finish
  end

  def after_tuning
    track_cumulative_feature_importance
  end
+
+ def cleanup
+   model.adapter.delete_wandb_project
+   Wandb.finish
+ end
  end
  end
  end
data/app/models/easy_ml/models/xgboost.rb
@@ -135,6 +135,12 @@ module EasyML
  end
  end

+ def cleanup
+   model.callbacks.each do |callback|
+     callback.cleanup if callback.respond_to?(:cleanup)
+   end
+ end
+
  def prepare_callbacks(tuner)
    set_wandb_project(tuner.project_name)

@@ -421,11 +427,11 @@ module EasyML
  def prepare_data
    if @d_train.nil?
      col_order = dataset.col_order
-     x_sample, y_sample = dataset.train(split_ys: true, limit: 5, select: col_order)
+     x_sample, y_sample = dataset.train(split_ys: true, limit: 5, select: col_order, lazy: true)
      preprocess(x_sample, y_sample) # Ensure we fail fast if the dataset is misconfigured
-     x_train, y_train = dataset.train(split_ys: true, select: col_order)
-     x_valid, y_valid = dataset.valid(split_ys: true, select: col_order)
-     x_test, y_test = dataset.test(split_ys: true, select: col_order)
+     x_train, y_train = dataset.train(split_ys: true, select: col_order, lazy: true)
+     x_valid, y_valid = dataset.valid(split_ys: true, select: col_order, lazy: true)
+     x_test, y_test = dataset.test(split_ys: true, select: col_order, lazy: true)
      @d_train = preprocess(x_train, y_train)
      @d_valid = preprocess(x_valid, y_valid)
      @d_test = preprocess(x_test, y_test)
@@ -434,21 +440,60 @@ module EasyML
    [@d_train, @d_valid, @d_test]
  end

+ def trainable?
+   untrainable_columns.empty?
+ end
+
+ def untrainable_columns
+   model.dataset.refresh if model.dataset.processed.nil?
+
+   model.dataset.list_nulls(
+     model.dataset.processed.data(lazy: true)
+   )
+ end
+
  def preprocess(xs, ys = nil)
    return xs if xs.is_a?(::XGBoost::DMatrix)
+   lazy = xs.is_a?(Polars::LazyFrame)
+   return xs if (lazy ? xs.limit(1).collect : xs).shape[0] == 0
+
+   weights_col = model.weights_column || nil
+
+   if weights_col == model.dataset.target
+     raise ArgumentError, "Weight column cannot be the target column"
+   end
+
+   # Extract feature columns (all columns except label and weight)
+   feature_cols = xs.columns
+   feature_cols -= [weights_col] if weights_col
+
+   # Get features, labels and weights
+   begin
+     features = lazy ? xs.select(feature_cols).collect.to_numo : xs.select(feature_cols).to_numo
+   rescue => e
+     binding.pry
+   end
+   weights = weights_col ? (lazy ? xs.select(weights_col).collect.to_numo : xs.select(weights_col).to_numo) : nil
+   weights = weights.flatten if weights
+   if ys.present?
+     ys = ys.is_a?(Array) ? Polars::Series.new(ys) : ys
+     labels = lazy ? ys.collect.to_numo.flatten : ys.to_numo.flatten
+   else
+     labels = nil
+   end
+
+   kwargs = {
+     label: labels,
+     weight: weights,
+   }.compact

-   orig_xs = xs.dup
-   column_names = xs.columns
-   xs = _preprocess(xs)
-   ys = ys.nil? ? nil : _preprocess(ys).flatten
-   kwargs = { label: ys }.compact
    begin
-     ::XGBoost::DMatrix.new(xs, **kwargs).tap do |dmat|
-       dmat.feature_names = column_names
+     ::XGBoost::DMatrix.new(features, **kwargs).tap do |dmatrix|
+       dmatrix.feature_names = feature_cols
      end
    rescue StandardError => e
-     problematic_columns = orig_xs.schema.select { |k, v| [Polars::Categorical, Polars::String].include?(v) }
-     problematic_xs = orig_xs.select(problematic_columns.keys)
+     problematic_columns = xs.schema.select { |k, v| [Polars::Categorical, Polars::String].include?(v) }
+     problematic_xs = lazy ? xs.lazy.select(problematic_columns.keys).collect : xs.select(problematic_columns.keys)
      raise %(
        Error building data for XGBoost.
        Apply preprocessing to columns
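
The rewritten preprocess now feeds XGBoost a weighted DMatrix: the configured weights_column is split out of the feature matrix and passed as per-row weights. A standalone sketch of that column split (frame and column names hypothetical; to_numo is the polars-df/numo bridge this code already uses):

    xs = Polars::DataFrame.new({
      "age" => [22.0, 38.0],
      "fare" => [7.25, 71.28],
      "sample_weight" => [1.0, 2.5],
    })
    weights_col = "sample_weight"
    feature_cols = xs.columns - [weights_col]          # => ["age", "fare"]
    features = xs.select(feature_cols).to_numo         # the matrix XGBoost trains on
    weights = xs.select(weights_col).to_numo.flatten   # per-row weights, not a feature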
@@ -501,29 +546,6 @@ module EasyML
    cb_container.after_iteration(@booster, current_iteration, d_train, evals)
  end

- def _preprocess(df)
-   return df if df.is_a?(Array)
-
-   df.to_a.map do |row|
-     row.values.map do |value|
-       case value
-       when Time
-         value.to_i # Convert Time to Unix timestamp
-       when Date
-         value.to_time.to_i # Convert Date to Unix timestamp
-       when String
-         value
-       when TrueClass, FalseClass
-         value ? 1.0 : 0.0 # Convert booleans to 1.0 and 0.0
-       when Integer
-         value
-       else
-         value.to_f # Ensure everything else is converted to a float
-       end
-     end
-   end
- end
-
  def initialize_model
    @xgboost_model = model_class.new(n_estimators: @hyperparameters.to_h.dig(:n_estimators))
    if block_given?
data/app/models/easy_ml/retraining_run.rb
@@ -150,7 +150,7 @@ module EasyML

  training_model.dataset.refresh
  evaluator = retraining_job.evaluator.symbolize_keys
- x_test, y_test = training_model.dataset.test(split_ys: true)
+ x_test, y_test = training_model.dataset.test(split_ys: true, all_columns: true)
  y_pred = training_model.predict(x_test)

  metric = evaluator[:metric].to_sym
data/app/serializers/easy_ml/model_serializer.rb
@@ -27,6 +27,7 @@ module EasyML
  :model_type,
  :task,
  :objective,
+ :weights_column,
  :metrics,
  :dataset_id,
  :status,
data/lib/easy_ml/core/tuner/adapters/base_adapter.rb
@@ -18,12 +18,22 @@ module EasyML
  end

  def defaults
-   {}
+   model.adapter.hyperparameters.class.hyperparameter_constants.transform_values do |constant|
+     values = constant.slice(:min, :max, :step, :options)
+     if values.key?(:options)
+       values[:options] = values[:options].map { |option| option[:value] }
+     end
+     values
+   end
  end

  def run_trial(trial)
    config = deep_merge_defaults(self.config.clone.deep_symbolize_keys)
-   suggest_parameters(trial, config)
+   # For first trial, re-use the original hyperparameters, so they
+   # serve as our starting point/imputers
+   unless trial == 1
+     suggest_parameters(trial, config)
+   end
    yield model
  end

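Tuner search spaces are now derived from the hyperparameter constants declared on the model class, which is what makes the hard-coded XGBoost defaults removed at the end of this diff redundant. A sketch of the reshaping with illustrative constants:

    constants = {
      learning_rate: { label: "Learning Rate", min: 0.001, max: 0.1, step: 0.001 },
      booster: { options: [{ value: "gbtree", label: "Gradient Boosted Tree" }] },
    }
    constants.transform_values do |constant|
      values = constant.slice(:min, :max, :step, :options)
      values[:options] = values[:options].map { |o| o[:value] } if values.key?(:options)
      values
    end
    # => { learning_rate: { min: 0.001, max: 0.1, step: 0.001 },
    #      booster: { options: ["gbtree"] } }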
@@ -57,8 +67,11 @@ module EasyML
    min = param_config[:min]
    max = param_config[:max]
    log = param_config[:log]
+   options = param_config[:options]

-   if log
+   if options
+     trial.suggest_categorical(param_name.to_s, options)
+   elsif log
      trial.suggest_loguniform(param_name.to_s, min, max)
    elsif max.is_a?(Integer) && min.is_a?(Integer)
      trial.suggest_int(param_name.to_s, min, max)
data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
@@ -5,23 +5,6 @@ module EasyML
  class Tuner
  module Adapters
  class XGBoostAdapter < BaseAdapter
-   def defaults
-     {
-       learning_rate: {
-         min: 0.001,
-         max: 0.1,
-         log: true,
-       },
-       n_estimators: {
-         min: 100,
-         max: 1_000,
-       },
-       max_depth: {
-         min: 2,
-         max: 20,
-       },
-     }
-   end
  end
  end
  end