easy_ml 0.2.0.pre.rc84 → 0.2.0.pre.rc88

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. checksums.yaml +4 -4
  2. data/app/controllers/easy_ml/datasets_controller.rb +19 -3
  3. data/app/frontend/components/dataset/PreprocessingConfig.tsx +523 -150
  4. data/app/frontend/types/dataset.ts +5 -2
  5. data/app/models/easy_ml/column/imputers/base.rb +23 -2
  6. data/app/models/easy_ml/column/imputers/embedding_encoder.rb +18 -0
  7. data/app/models/easy_ml/column/imputers/imputer.rb +1 -0
  8. data/app/models/easy_ml/column/imputers/most_frequent.rb +1 -1
  9. data/app/models/easy_ml/column/imputers/one_hot_encoder.rb +1 -1
  10. data/app/models/easy_ml/column/imputers/ordinal_encoder.rb +1 -1
  11. data/app/models/easy_ml/column/imputers.rb +47 -41
  12. data/app/models/easy_ml/column/selector.rb +2 -2
  13. data/app/models/easy_ml/column.rb +260 -56
  14. data/app/models/easy_ml/column_history.rb +6 -0
  15. data/app/models/easy_ml/column_list.rb +30 -1
  16. data/app/models/easy_ml/dataset/learner/lazy/embedding.rb +10 -0
  17. data/app/models/easy_ml/dataset/learner/lazy/query.rb +2 -0
  18. data/app/models/easy_ml/dataset/learner.rb +11 -0
  19. data/app/models/easy_ml/dataset.rb +6 -19
  20. data/app/models/easy_ml/lineage_history.rb +17 -0
  21. data/app/models/easy_ml/model.rb +11 -1
  22. data/app/models/easy_ml/models/xgboost.rb +37 -7
  23. data/app/models/easy_ml/pca_model.rb +21 -0
  24. data/app/models/easy_ml/prediction.rb +2 -1
  25. data/app/serializers/easy_ml/column_serializer.rb +13 -1
  26. data/config/initializers/inflections.rb +1 -0
  27. data/lib/easy_ml/data/dataset_manager/writer/append_only.rb +6 -8
  28. data/lib/easy_ml/data/dataset_manager/writer/base.rb +15 -2
  29. data/lib/easy_ml/data/dataset_manager/writer/partitioned.rb +0 -1
  30. data/lib/easy_ml/data/dataset_manager/writer.rb +2 -0
  31. data/lib/easy_ml/data/embeddings/compressor.rb +179 -0
  32. data/lib/easy_ml/data/embeddings/embedder.rb +226 -0
  33. data/lib/easy_ml/data/embeddings.rb +61 -0
  34. data/lib/easy_ml/data/polars_column.rb +3 -0
  35. data/lib/easy_ml/data/polars_reader.rb +54 -23
  36. data/lib/easy_ml/data/polars_schema.rb +28 -2
  37. data/lib/easy_ml/data/splits/file_split.rb +7 -2
  38. data/lib/easy_ml/data.rb +1 -0
  39. data/lib/easy_ml/embedding_store.rb +92 -0
  40. data/lib/easy_ml/engine.rb +4 -2
  41. data/lib/easy_ml/predict.rb +42 -20
  42. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +5 -0
  43. data/lib/easy_ml/railtie/templates/migration/add_is_primary_key_to_easy_ml_columns.rb.tt +9 -0
  44. data/lib/easy_ml/railtie/templates/migration/add_metadata_to_easy_ml_predictions.rb.tt +6 -0
  45. data/lib/easy_ml/railtie/templates/migration/add_pca_model_id_to_easy_ml_columns.rb.tt +9 -0
  46. data/lib/easy_ml/railtie/templates/migration/add_workflow_status_to_easy_ml_dataset_histories.rb.tt +13 -0
  47. data/lib/easy_ml/railtie/templates/migration/create_easy_ml_pca_models.rb.tt +14 -0
  48. data/lib/easy_ml/version.rb +1 -1
  49. data/lib/easy_ml.rb +1 -0
  50. data/public/easy_ml/assets/.vite/manifest.json +2 -2
  51. data/public/easy_ml/assets/assets/Application-DfPoyRr8.css +1 -0
  52. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-KENNRQpC.js +533 -0
  53. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-KENNRQpC.js.map +1 -0
  54. metadata +59 -6
  55. data/lib/tasks/profile.rake +0 -40
  56. data/public/easy_ml/assets/assets/Application-nnn_XLuL.css +0 -1
  57. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-Bbf3mD_b.js +0 -522
  58. data/public/easy_ml/assets/assets/entrypoints/Application.tsx-Bbf3mD_b.js.map +0 -1
@@ -25,6 +25,8 @@
25
25
  # last_datasource_sha :string
26
26
  # last_feature_sha :string
27
27
  # in_raw_dataset :boolean
28
+ # is_primary_key :boolean
29
+ # pca_model_id :integer
28
30
  #
29
31
  module EasyML
30
32
  class Column < ActiveRecord::Base
@@ -36,16 +38,19 @@ module EasyML
36
38
 
37
39
  belongs_to :dataset, class_name: "EasyML::Dataset"
38
40
  belongs_to :feature, class_name: "EasyML::Feature", optional: true
41
+ belongs_to :pca_model, class_name: "EasyML::PCAModel", optional: true
39
42
  has_many :lineages, class_name: "EasyML::Lineage"
40
43
 
41
44
  validates :name, presence: true
42
45
  validates :name, uniqueness: { scope: :dataset_id }
43
46
 
44
47
  before_save :ensure_valid_datatype
45
- after_save :handle_date_column_change
48
+ before_save :ensure_valid_encoding
49
+ after_save :handle_unique_attrs
46
50
  before_save :set_defaults
47
51
  before_save :set_feature_lineage
48
52
  before_save :set_polars_datatype
53
+ before_save :ensure_cast_works
49
54
  # after_find :ensure_feature_exists
50
55
 
51
56
  # Scopes
@@ -128,13 +133,23 @@ module EasyML
128
133
  "#<#{self.class.name} #{display_attributes.map { |k, v| "#{k}: #{v}" }.join(", ")}>"
129
134
  end
130
135
 
136
+ def processed_columns
137
+ has_virtual_columns? ? virtual_columns : [name]
138
+ end
139
+
131
140
  def aliases
132
141
  [name].concat(virtual_columns)
133
142
  end
134
143
 
144
+ def has_virtual_columns?
145
+ one_hot? || embedded?
146
+ end
147
+
135
148
  def virtual_columns
136
149
  if one_hot?
137
150
  allowed_categories.map { |cat| "#{name}_#{cat}" }
151
+ elsif embedded?
152
+ ["#{name}_embedding"]
138
153
  else
139
154
  []
140
155
  end
@@ -237,6 +252,34 @@ module EasyML
237
252
  end
238
253
  end
239
254
 
255
+ VALID_ENCODINGS = [:one_hot, :ordinal, :embedding].freeze
256
+
257
+ def encoding
258
+ preprocessing_steps.deep_symbolize_keys.dig(:training, :encoding)&.to_sym
259
+ end
260
+
261
+ def one_hot?
262
+ encoding == :one_hot
263
+ end
264
+
265
+ def ordinal_encoding?
266
+ encoding == :ordinal
267
+ end
268
+
269
+ def embedded?
270
+ encoding == :embedding
271
+ end
272
+
273
+ def encoding_applies?(encoding_type)
274
+ encoding == encoding_type
275
+ end
276
+
277
+ def validate_encoding
278
+ return true if encoding.nil?
279
+ return false unless VALID_ENCODINGS.include?(encoding)
280
+ true
281
+ end
282
+
240
283
  EasyML::Data::PolarsColumn::TYPE_MAP.keys.each do |dtype|
241
284
  define_method("#{dtype}?") do
242
285
  datatype.to_s == dtype.to_s
@@ -341,18 +384,52 @@ module EasyML
341
384
  (read_attribute(:preprocessing_steps) || {}).symbolize_keys
342
385
  end
343
386
 
344
- def one_hot?
345
- preprocessing_steps.deep_symbolize_keys.dig(:training, :params, :one_hot) == true
387
+ def embedding_column
388
+ return nil unless embedded?
389
+ virtual_columns.first
346
390
  end
347
391
 
348
- def ordinal_encoding?
349
- preprocessing_steps.deep_symbolize_keys.dig(:training, :params, :ordinal_encoding) == true
392
+ def embed(df, fit: false)
393
+ return df if df.columns.include?(embedding_column) && df.filter(Polars.col(embedding_column).is_null).empty?
394
+ return df unless name.present?
395
+ return df unless embedded?
396
+
397
+ if fit
398
+ pca_model = get_pca_model
399
+ return if pca_model.fit_at.present? && pca_model.fit_at > dataset.datasource.refreshed_at && !pca_model_outdated?
400
+ end
401
+
402
+ actually_generate_embeddings(df, fit: fit)
403
+
404
+ df = decorate_embeddings(df, compressed: true)
405
+ df
350
406
  end
351
407
 
352
- def encoding
353
- return nil unless categorical?
354
- return :ordinal if ordinal_encoding?
355
- return :one_hot
408
+ def store_embeddings(df, compressed: false)
409
+ return unless embedded?
410
+
411
+ embedding_store.store(df, compressed: compressed)
412
+ end
413
+
414
+ def embedding_config
415
+ return nil unless embedded?
416
+ preprocessing_steps = self.preprocessing_steps.deep_symbolize_keys
417
+
418
+ preprocessing_steps.dig(:training, :params).slice(:llm, :preset, :dimensions).merge!(
419
+ column: self.name,
420
+ output_column: "#{self.name}_embedding",
421
+ config: {
422
+ default_options: {
423
+ embeddings_model_name: preprocessing_steps.deep_symbolize_keys.dig(:training, :params, :model),
424
+ },
425
+ },
426
+ )
427
+ end
428
+
429
+ def embedding_store
430
+ return nil unless embedded?
431
+
432
+ @embedding_store ||= EasyML::EmbeddingStore.new(self)
356
433
  end
357
434
 
358
435
  def categorical_min
@@ -458,94 +535,212 @@ module EasyML
458
535
  value
459
536
  end
460
537
 
538
+ def get_pca_model
539
+ pca_model || build_pca_model
540
+ end
541
+
542
+ def n_dimensions
543
+ return nil unless embedded?
544
+
545
+ preprocessing_steps.deep_symbolize_keys.dig(:training, :params, :dimensions)
546
+ end
547
+
461
548
  private
462
549
 
550
+ def ensure_cast_works
551
+ begin
552
+ raw.data(cast: { name => polars_datatype })
553
+ rescue => e
554
+ raw_dtype = EasyML::Data::PolarsColumn.polars_to_sym(
555
+ raw.data(cast: false, limit: 1).schema[name]
556
+ )
557
+ errors.add(:datatype, "Can't cast from #{raw_dtype} to #{datatype}")
558
+ end
559
+ end
560
+
561
+ def pca_model_outdated?
562
+ return false unless EasyML::Data::Embeddings::Compressor::COMPRESSION_ENABLED
563
+
564
+ pca_model = get_pca_model
565
+ return false unless pca_model.persisted?
566
+ return false unless n_dimensions.present?
567
+
568
+ pca_model.model.params.dig(:n_components) != n_dimensions
569
+ end
570
+
571
+ def needs_embed(df, compressed: false)
572
+ if df.columns.exclude?(name)
573
+ Polars::DataFrame.new
574
+ elsif embedding_store.empty?(compressed: compressed)
575
+ df
576
+ else
577
+ stored_embeddings = embedding_store.query(lazy: true, compressed: compressed)
578
+ df.filter(Polars.col(name).is_null.not_).join(
579
+ stored_embeddings.select(name),
580
+ on: name,
581
+ how: "anti",
582
+ )
583
+ end
584
+ end
585
+
586
+ def decorate_embeddings(df, compressed: false)
587
+ if df.columns.include?(embedding_column)
588
+ orig_col_order = df.columns
589
+ df = df.drop(embedding_column) if df.columns.include?(embedding_column)
590
+ else
591
+ orig_col_order = df.columns + [embedding_column]
592
+ end
593
+
594
+ df = df.join(
595
+ embedding_store.query(lazy: true, compressed: compressed),
596
+ on: name,
597
+ how: "left",
598
+ ).select(orig_col_order)
599
+ df
600
+ end
601
+
602
+ def embed_and_compress(df, fit: false)
603
+ needs_embed = self.needs_embed(df, compressed: false)
604
+ needs_recompress = fit && pca_model_outdated?
605
+
606
+ extra_params = {
607
+ df: needs_embed,
608
+ pca_model: fit ? nil : get_pca_model.model,
609
+ }.compact
610
+ generator = EasyML::Data::Embeddings.new(embedding_config.merge!(extra_params))
611
+
612
+ if needs_embed.shape[0] > 0
613
+ needs_embed = generator.embed
614
+ store_embeddings(needs_embed, compressed: false)
615
+ end
616
+
617
+ # When the PCA model is outdated, we need to re-fit the PCA model and re-compress,
618
+ # but we don't need to re-generate the full embeddings again
619
+ if needs_recompress
620
+ needs_embed = decorate_embeddings(df.clone, compressed: false)
621
+ embedding_store.compressed_store.wipe
622
+ end
623
+
624
+ needs_embed = self.needs_embed(df, compressed: true)
625
+ return df if needs_embed.empty?
626
+ if needs_embed.columns.exclude?(embedding_column) || ((needs_embed.shape[0] == 1) && needs_embed.filter(Polars.col(embedding_column).is_null).count == 1)
627
+ needs_embed = decorate_embeddings(needs_embed, compressed: false)
628
+ end
629
+
630
+ if (n_dimensions.present? && needs_embed.shape[1] > 0 && n_dimensions < needs_embed[embedding_column][0].count)
631
+ compressed = generator.compress(needs_embed, fit: fit)
632
+ store_embeddings(compressed, compressed: true)
633
+ else
634
+ store_embeddings(needs_embed, compressed: true)
635
+ end
636
+
637
+ if fit
638
+ embedding_store.compact
639
+
640
+ get_pca_model.update(
641
+ model: generator.pca_model,
642
+ fit_at: Time.now,
643
+ )
644
+ update(pca_model_id: get_pca_model.id)
645
+ end
646
+ end
647
+
648
+ def actually_generate_embeddings(df, fit: false)
649
+ return df if df.empty?
650
+
651
+ embed_and_compress(df, fit: fit)
652
+ end
653
+
463
654
  def set_defaults
464
655
  self.preprocessing_steps = set_preprocessing_steps_defaults
465
656
  end
466
657
 
467
658
  def set_preprocessing_steps_defaults
468
- preprocessing_steps.inject({}) do |h, (type, config)|
659
+ preprocessing_steps.deep_symbolize_keys.inject({}) do |h, (type, config)|
469
660
  h.tap do
470
661
  h[type] = set_preprocessing_step_defaults(config)
471
662
  end
472
663
  end
473
664
  end
474
665
 
475
- ALLOWED_PARAMS = {
476
- constant: [:constant],
477
- categorical: %i[categorical_min one_hot ordinal_encoding],
478
- most_frequent: %i[one_hot ordinal_encoding],
479
- mean: [:clip],
480
- median: [:clip],
481
- }
482
-
483
666
  REQUIRED_PARAMS = {
667
+ embedding: %i[llm model],
484
668
  constant: [:constant],
485
- categorical: %i[categorical_min one_hot ordinal_encoding],
669
+ categorical: %i[categorical_min],
486
670
  }
487
671
 
488
672
  DEFAULT_PARAMS = {
489
673
  categorical_min: 1,
490
- one_hot: true,
491
- ordinal_encoding: false,
492
674
  clip: { min: 0, max: 1_000_000_000 },
493
675
  constant: nil,
676
+ llm: "openai",
677
+ model: "text-embedding-3-small",
678
+ preset: :full,
494
679
  }
495
680
 
496
- XOR_PARAMS = [{
497
- params: [:one_hot, :ordinal_encoding],
498
- default: :one_hot,
499
- }]
500
-
501
681
  def set_preprocessing_step_defaults(config)
502
682
  config.deep_symbolize_keys!
503
683
  config[:params] ||= {}
504
- params = config[:params].symbolize_keys
684
+ params = config[:params].deep_symbolize_keys
505
685
 
506
686
  required = REQUIRED_PARAMS.fetch(config[:method].to_sym, [])
507
- allowed = ALLOWED_PARAMS.fetch(config[:method].to_sym, [])
508
687
 
509
688
  missing = required - params.keys
510
- missing.reject! do |param|
511
- XOR_PARAMS.any? do |rule|
512
- if rule[:params].include?(param)
513
- missing_param = rule[:params].find { |p| p != param }
514
- params[missing_param] == true
515
- else
516
- false
517
- end
518
- end
519
- end
520
- extra = params.keys - allowed
521
-
522
689
  missing.each do |key|
523
690
  params[key] = DEFAULT_PARAMS.fetch(key)
524
691
  end
525
692
 
526
- extra.each do |key|
527
- params.delete(key)
528
- end
693
+ config.merge!(params: params)
694
+ end
529
695
 
530
- # Only set one of one_hot or ordinal_encoding to true,
531
- # by default set one_hot to true
532
- xor = XOR_PARAMS.find { |rule| rule[:params] & params.keys == rule[:params] }
533
- if xor && xor[:params].all? { |param| params[param] }
534
- xor[:params].each { |param| params[param] = false }
535
- params[xor[:default]] = true
696
+ def handle_unique_attrs
697
+ return unless primary_key_changed? || target_changed? || is_date_column_changed?
698
+
699
+ Column.transaction do
700
+ handle_date_column_change
701
+ handle_primary_key_change
702
+ handle_target_change
703
+ resync_dataset if dataset.processed_schema.present? # When using Import, columns are created before the dataset
536
704
  end
705
+ end
537
706
 
538
- config.merge!(params: params)
707
+ def target_changed?
708
+ saved_change_to_is_target? && is_target?
709
+ end
710
+
711
+ def primary_key_changed?
712
+ saved_change_to_is_primary_key? && is_primary_key?
713
+ end
714
+
715
+ def is_date_column_changed?
716
+ saved_change_to_is_date_column? && is_date_column?
717
+ end
718
+
719
+ def handle_target_change
720
+ return unless target_changed?
721
+
722
+ dataset.columns.where.not(id: id).update_all(is_target: false)
723
+ end
724
+
725
+ def primary_key_changed?
726
+ saved_change_to_is_primary_key? && is_primary_key?
727
+ end
728
+
729
+ def handle_primary_key_change
730
+ return unless primary_key_changed?
731
+
732
+ dataset.columns.where.not(id: id).update_all(is_primary_key: false)
539
733
  end
540
734
 
541
735
  def handle_date_column_change
542
- return unless saved_change_to_is_date_column? && is_date_column?
736
+ return unless is_date_column_changed?
543
737
 
544
- Column.transaction do
545
- dataset.columns.where.not(id: id).update_all(is_date_column: false)
546
- dataset.learn_statistics
547
- dataset.columns.sync
548
- end
738
+ dataset.columns.where.not(id: id).update_all(is_date_column: false)
739
+ end
740
+
741
+ def resync_dataset
742
+ dataset.learn_statistics
743
+ dataset.columns.sync
549
744
  end
550
745
 
551
746
  def ensure_valid_datatype
@@ -557,6 +752,15 @@ module EasyML
557
752
  throw :abort
558
753
  end
559
754
 
755
+ def ensure_valid_encoding
756
+ return true if encoding.nil?
757
+
758
+ unless VALID_ENCODINGS.include?(encoding)
759
+ errors.add(:encoding, "must be one of: #{VALID_ENCODINGS.join(", ")}")
760
+ throw(:abort)
761
+ end
762
+ end
763
+
560
764
  NUMERIC_METHODS = %i[mean median].freeze
561
765
 
562
766
  def data_selector
@@ -30,6 +30,8 @@
30
30
  # last_datasource_sha :string
31
31
  # last_feature_sha :string
32
32
  # in_raw_dataset :boolean
33
+ # is_primary_key :boolean
34
+ # pca_model_id :integer
33
35
  #
34
36
  module EasyML
35
37
  class ColumnHistory < ActiveRecord::Base
@@ -38,5 +40,9 @@ module EasyML
38
40
  scope :required, -> { where(is_computed: false, hidden: false, is_target: false).where("preprocessing_steps IS NULL OR preprocessing_steps::text = '{}'::text") }
39
41
  scope :computed, -> { where(is_computed: true) }
40
42
  scope :raw, -> { where(is_computed: false) }
43
+
44
+ def get_pca_model
45
+ column.pca_model
46
+ end
41
47
  end
42
48
  end
@@ -84,15 +84,44 @@ module EasyML
84
84
  end
85
85
  end
86
86
 
87
+ def col_order(inference: false)
88
+ scope = reject(&:hidden)
89
+ scope = scope.reject(&:is_target) if inference
90
+ scope.flat_map do |col|
91
+ col.processed_columns.map do |name|
92
+ [col.id, name]
93
+ end
94
+ end.sort.map { |arr| arr[1] }.uniq
95
+ end
96
+
97
+ def cast(processed_or_raw)
98
+ columns = where(is_computed: false)
99
+ is_processed = processed_or_raw == :processed
100
+ columns.reduce({}) do |h, col|
101
+ h.tap do
102
+ dtype = (col.ordinal_encoding? && is_processed) ? nil : col.read_attribute(:polars_datatype)
103
+ next if dtype.nil? || dtype.blank?
104
+
105
+ h[col.name] = dtype.constantize
106
+ end
107
+ end.compact
108
+ end
109
+
87
110
  def one_hot?(column)
88
111
  one_hots.map(&:name).detect do |one_hot_col|
89
112
  column.start_with?(one_hot_col)
90
113
  end
91
114
  end
92
115
 
116
+ def embedded?(column)
117
+ column_list.select(&:embedded?).detect do |col|
118
+ column == col.embedding_column
119
+ end
120
+ end
121
+
93
122
  def syncable
94
123
  dataset.processed_schema.keys.select do |col|
95
- !one_hot?(col)
124
+ !one_hot?(col) && !embedded?(col)
96
125
  end
97
126
  end
98
127
 
@@ -0,0 +1,10 @@
1
+ module EasyML
2
+ class Dataset
3
+ class Learner
4
+ class Lazy
5
+ class Embedding < Query
6
+ end
7
+ end
8
+ end
9
+ end
10
+ end
@@ -15,6 +15,8 @@ module EasyML
15
15
  Lazy::Datetime
16
16
  when :boolean
17
17
  Lazy::Boolean
18
+ when :embedding
19
+ Lazy::Embedding
18
20
  when :null
19
21
  Lazy::Null
20
22
  else
@@ -18,12 +18,23 @@ module EasyML
18
18
 
19
19
  def learn
20
20
  prepare
21
+ fit_models
21
22
  learn_statistics
22
23
  save_statistics
23
24
  end
24
25
 
25
26
  private
26
27
 
28
+ def fit_models
29
+ fit_embedding_models
30
+ end
31
+
32
+ def fit_embedding_models
33
+ columns.select(&:embedded?).each do |col|
34
+ col.embed(dataset.train(all_columns: true), fit: true)
35
+ end
36
+ end
37
+
27
38
  def save_statistics
28
39
  columns.each do |col|
29
40
  col.merge_statistics(statistics.dig(col.name))
@@ -84,6 +84,7 @@ module EasyML
84
84
  preprocessing_strategies: EasyML::Column::Imputers.constants[:preprocessing_strategies],
85
85
  feature_options: EasyML::Features::Registry.list_flat,
86
86
  splitter_constants: EasyML::Splitter.constants,
87
+ embedding_constants: EasyML::Data::Embeddings::Embedder.constants,
87
88
  }
88
89
  end
89
90
 
@@ -393,7 +394,6 @@ module EasyML
393
394
  end
394
395
 
395
396
  def lock_dataset
396
- data = processed.data(limit: 1).to_a.any? ? processed.data : raw.data
397
397
  with_lock do |client|
398
398
  yield
399
399
  end
@@ -600,6 +600,10 @@ module EasyML
600
600
  }.compact.deep_symbolize_keys
601
601
  end
602
602
 
603
+ def dataset_primary_key
604
+ @dataset_primary_key ||= preloaded_columns.find(&:is_primary_key)&.name
605
+ end
606
+
603
607
  def target
604
608
  @target ||= preloaded_columns.find(&:is_target)&.name
605
609
  end
@@ -617,24 +621,7 @@ module EasyML
617
621
  end
618
622
 
619
623
  def col_order(inference: false)
620
- # Filter preloaded columns in memory
621
- scope = preloaded_columns.reject(&:hidden)
622
- scope = scope.reject(&:is_target) if inference
623
-
624
- # Get one_hot columns for category mapping
625
- one_hots = scope.select(&:one_hot?)
626
- one_hot_cats = columns.allowed_categories.symbolize_keys
627
-
628
- # Map columns to names, handling one_hot expansion
629
- scope.flat_map do |col|
630
- if col.one_hot?
631
- one_hot_cats[col.name.to_sym].map do |cat|
632
- "#{col.name}_#{cat}"
633
- end
634
- else
635
- col.name
636
- end
637
- end.sort
624
+ preloaded_columns.col_order(inference: inference)
638
625
  end
639
626
 
640
627
  def column_mask(df, inference: false)
@@ -1,3 +1,20 @@
1
+ # == Schema Information
2
+ #
3
+ # Table name: easy_ml_lineage_histories
4
+ #
5
+ # id :bigint not null, primary key
6
+ # easy_ml_lineage_id :integer not null
7
+ # column_id :integer not null
8
+ # key :string not null
9
+ # description :string
10
+ # occurred_at :datetime
11
+ # created_at :datetime not null
12
+ # updated_at :datetime not null
13
+ # history_started_at :datetime not null
14
+ # history_ended_at :datetime
15
+ # history_user_id :integer
16
+ # snapshot_id :string
17
+ #
1
18
  module EasyML
2
19
  class LineageHistory < ActiveRecord::Base
3
20
  self.table_name = "easy_ml_lineage_histories"
@@ -296,14 +296,24 @@ module EasyML
296
296
  )
297
297
  end
298
298
 
299
- def predict(xs)
299
+ def prepare_predict(xs)
300
300
  load_model!
301
301
  unless xs.is_a?(XGBoost::DMatrix)
302
302
  xs = dataset.normalize(xs, inference: true)
303
303
  end
304
+ xs
305
+ end
306
+
307
+ def predict(xs)
308
+ xs = prepare_predict(xs)
304
309
  adapter.predict(xs)
305
310
  end
306
311
 
312
+ def predict_proba(xs)
313
+ xs = prepare_predict(xs)
314
+ adapter.predict_proba(xs)
315
+ end
316
+
307
317
  def save_model_file
308
318
  raise "No trained model! Need to train model before saving (call model.fit)" unless is_fit?
309
319
  return unless adapter.loaded?