easy_ml 0.2.0.pre.rc101 → 0.2.0.pre.rc103

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. checksums.yaml +4 -4
  2. data/app/controllers/easy_ml/datasets_controller.rb +1 -0
  3. data/app/frontend/components/dataset/PreprocessingConfig.tsx +0 -1
  4. data/app/frontend/components/dataset/splitters/types.ts +3 -4
  5. data/app/frontend/pages/NewDatasetPage.tsx +17 -0
  6. data/app/frontend/types/datasource.ts +14 -6
  7. data/app/models/easy_ml/column/imputers/base.rb +3 -1
  8. data/app/models/easy_ml/column.rb +26 -13
  9. data/app/models/easy_ml/column_list.rb +2 -2
  10. data/app/models/easy_ml/dataset/learner/lazy/datetime.rb +3 -1
  11. data/app/models/easy_ml/dataset/learner/lazy/numeric.rb +24 -5
  12. data/app/models/easy_ml/dataset/learner/lazy/query.rb +19 -7
  13. data/app/models/easy_ml/dataset/learner/lazy/string.rb +4 -1
  14. data/app/models/easy_ml/dataset/learner/lazy.rb +17 -4
  15. data/app/models/easy_ml/dataset.rb +47 -9
  16. data/app/models/easy_ml/dataset_history.rb +1 -0
  17. data/app/models/easy_ml/feature.rb +5 -13
  18. data/app/models/easy_ml/lineage.rb +2 -1
  19. data/app/models/easy_ml/models/xgboost/evals_callback.rb +1 -0
  20. data/app/models/easy_ml/models/xgboost.rb +8 -3
  21. data/app/models/easy_ml/prediction.rb +1 -1
  22. data/app/models/easy_ml/splitters/base_splitter.rb +4 -8
  23. data/app/models/easy_ml/splitters/date_splitter.rb +2 -1
  24. data/app/models/easy_ml/splitters/predefined_splitter.rb +8 -3
  25. data/lib/easy_ml/data/dataset_manager/schema/normalizer.rb +201 -0
  26. data/lib/easy_ml/data/dataset_manager/schema.rb +9 -0
  27. data/lib/easy_ml/data/dataset_manager.rb +5 -0
  28. data/lib/easy_ml/data/date_converter.rb +24 -165
  29. data/lib/easy_ml/data/polars_column.rb +4 -2
  30. data/lib/easy_ml/data/polars_reader.rb +5 -2
  31. data/lib/easy_ml/engine.rb +4 -0
  32. data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +1 -0
  33. data/lib/easy_ml/railtie/templates/migration/add_view_class_to_easy_ml_datasets.rb.tt +9 -0
  34. data/lib/easy_ml/version.rb +1 -1
  35. data/public/easy_ml/assets/.vite/manifest.json +1 -1
  36. data/public/easy_ml/assets/assets/entrypoints/{Application.tsx-BXwsBCuQ.js → Application.tsx-gkZ77wo8.js} +8 -8
  37. data/public/easy_ml/assets/assets/entrypoints/{Application.tsx-BXwsBCuQ.js.map → Application.tsx-gkZ77wo8.js.map} +1 -1
  38. metadata +7 -5
  39. data/lib/easy_ml/data/dataset_manager/normalizer.rb +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7f48937aea567de8e40bc34486c4ac945b860ca26654d8d3b06efa1c1d4a54f3
4
- data.tar.gz: 1abb8bb2e3f3ba8bb9c228d7a9691e8906ababa523ad0d7155cdafbd3ec62396
3
+ metadata.gz: ef3f840cce99d7205957fbb39a6b319a45035624dce2e4e10f681383cb088abf
4
+ data.tar.gz: e25100f792ad48cfa4feab7eb652a2d6c49bfc6e28f3bcb97c8150f9bdd1bfc5
5
5
  SHA512:
6
- metadata.gz: ef28fcb989d2934329e4da3c9a138d3fc7b4c9ae995d7ce021217f4507e24b17664d2dee9690b7a00105e804b41f490442e2b90f1d76f8c12d7ddca768ae43ba
7
- data.tar.gz: 3ba6f95ca3a660540e81a49c5eba84f530b606d89dc499e05aa288d26b90802dfc74b0b7615360002f1f36c1255f41ea6871378ffa1e93abc147b4f2a5c6ab0c
6
+ metadata.gz: 5f58395d392158d149db34ad5019a0e011164ca8d331846553e44e6564a291d88323ad0090c1c5ded60f696940b30949cba4e1a614fa9cd502e94372ef949707
7
+ data.tar.gz: 9497391351ad054308a985cc6b9e608f8dfef61be7417d66502cb11c26ca4f7825456b31aab010c016ac10feff791a8f7a01893743ecf11a26fabb9de7405b82
@@ -190,6 +190,7 @@ module EasyML
190
190
  :description,
191
191
  :datasource_id,
192
192
  :target,
193
+ :view_class,
193
194
  drop_cols: [],
194
195
  splitter_attributes: %i[
195
196
  splitter_type
@@ -1028,7 +1028,6 @@ export function PreprocessingConfig({
1028
1028
  label: strategy.label
1029
1029
  })) || [])
1030
1030
  ]}
1031
- options={constants.preprocessing_strategies[selectedType]}
1032
1031
  />
1033
1032
 
1034
1033
  {renderStrategySpecificInfo('training')}
@@ -1,12 +1,11 @@
1
- import type { ColumnType } from '../../../types/datasource';
1
+ import type { Constants } from '../../../types/datasource';
2
2
  import type { Datasource } from '../types/datasource';
3
3
 
4
4
  export type NewDatasetFormProps = {
5
5
  datasources: Datasource[];
6
- constants: {
7
- columns: ColumnType[];
8
- };
6
+ constants: Constants;
9
7
  }
8
+
10
9
  export type SplitterType =
11
10
  | 'date'
12
11
  | 'random'
@@ -78,6 +78,7 @@ export default function NewDatasetPage({ constants, datasources }: NewDatasetFor
78
78
  dataset: {
79
79
  name: '',
80
80
  datasource_id: '',
81
+ view_class: '',
81
82
  splitter_attributes: {
82
83
  splitter_type: selectedSplitterType,
83
84
  ...getDefaultConfig(selectedSplitterType)
@@ -249,6 +250,22 @@ export default function NewDatasetPage({ constants, datasources }: NewDatasetFor
249
250
  />
250
251
  </div>
251
252
 
253
+ <div>
254
+ <label
255
+ htmlFor="view_class"
256
+ className="block text-sm font-medium text-gray-700"
257
+ >
258
+ View Class
259
+ </label>
260
+ <SearchableSelect
261
+ value={formData.dataset.view_class}
262
+ onChange={(value) => setData('dataset.view_class', value)}
263
+ options={constants.available_views}
264
+ className="mt-1"
265
+ placeholder="Select a view class (optional)..."
266
+ />
267
+ </div>
268
+
252
269
  {selectedDatasource && (
253
270
  <div className={`rounded-lg p-4 ${
254
271
  selectedDatasource.sync_error
@@ -10,6 +10,19 @@ export interface Schema {
10
10
  [key: string]: ColumnType;
11
11
  }
12
12
 
13
+ export interface Constants {
14
+ column_types: Array<{ value: string; label: string }>;
15
+ preprocessing_strategies: any;
16
+ feature_options: any;
17
+ splitter_constants: any;
18
+ embedding_constants: any;
19
+ available_views: Array<{ value: string; label: string }>;
20
+ DATASOURCE_TYPES: Array<{ value: string; label: string; description: string }>;
21
+ s3: {
22
+ S3_REGIONS: Array<{ value: string; label: string }>;
23
+ };
24
+ }
25
+
13
26
  export interface Datasource {
14
27
  id: number;
15
28
  name: string;
@@ -23,10 +36,5 @@ export interface Datasource {
23
36
 
24
37
  export interface DatasourceFormProps {
25
38
  datasource?: Datasource;
26
- constants: {
27
- DATASOURCE_TYPES: Array<{ value: string; label: string; description: string }>;
28
- s3: {
29
- S3_REGIONS: Array<{ value: string; label: string }>;
30
- };
31
- };
39
+ constants: Constants;
32
40
  }
@@ -46,6 +46,8 @@ module EasyML
46
46
  end
47
47
 
48
48
  def param_applies?
49
+ return false unless params.present?
50
+
49
51
  params.keys.any? { |p| imputers_own_params.include?(p.to_sym) && params[p] != false }
50
52
  end
51
53
 
@@ -60,7 +62,7 @@ module EasyML
60
62
  end
61
63
 
62
64
  def imputers_own_params
63
- Imputers.params_by_class[self.class] || []
65
+ Imputers.params_by_class[self.class] || {}
64
66
  end
65
67
 
66
68
  def imputers_own_encodings
@@ -71,6 +71,7 @@ module EasyML
71
71
  scope :has_clip, -> { where("preprocessing_steps->'training'->>'params' IS NOT NULL AND preprocessing_steps->'training'->'params' @> jsonb_build_object('clip', jsonb_build_object())") }
72
72
  scope :needs_learn, -> {
73
73
  datasource_changed
74
+ .or(is_view)
74
75
  .or(feature_applied)
75
76
  .or(feature_changed)
76
77
  .or(column_changed)
@@ -88,6 +89,13 @@ module EasyML
88
89
  )
89
90
  }
90
91
 
92
+ scope :is_view, -> {
93
+ left_joins(dataset: :datasource)
94
+ .left_joins(:feature)
95
+ .where(
96
+ Dataset.arel_table[:view_class].not_eq(nil)
97
+ )
98
+ }
91
99
  scope :feature_changed, -> {
92
100
  where(feature_id: Feature.has_changes.map(&:id))
93
101
  }
@@ -514,27 +522,32 @@ module EasyML
514
522
  EasyML::Import::Column.from_config(config, dataset, action: action)
515
523
  end
516
524
 
517
- def cast_statement(df, df_col, expected_dtype)
518
- expected_dtype = expected_dtype.is_a?(Polars::DataType) ? expected_dtype.class : expected_dtype
519
- actual_type = df[df_col].dtype
525
+ def cast_statement(series = nil)
526
+ expected_dtype = polars_datatype
527
+ actual_type = series&.dtype || expected_dtype
528
+
529
+ return Polars.col(name).cast(expected_dtype).alias(name) if expected_dtype == actual_type
520
530
 
521
531
  cast_statement = case expected_dtype.to_s
522
- when "Polars::Boolean"
532
+ when /Polars::List/
533
+ # we should start tracking polars args so we can know what type of list it is
534
+ Polars.col(name)
535
+ when /Polars::Boolean/
523
536
  case actual_type.to_s
524
- when "Polars::Boolean"
525
- Polars.col(df_col).cast(expected_dtype)
526
- when "Polars::Utf8", "Polars::Categorical", "Polars::String"
527
- Polars.col(df_col).eq("true").cast(expected_dtype)
528
- when "Polars::Null"
529
- Polars.col(df_col)
537
+ when /Polars::Boolean/, /Polars::Int/
538
+ Polars.col(name).cast(expected_dtype)
539
+ when /Polars::Utf/, /Polars::Categorical/, /Polars::String/
540
+ Polars.col(name).eq("true").cast(expected_dtype)
541
+ when /Polars::Null/
542
+ Polars.col(name)
530
543
  else
531
- raise "Unexpected dtype: #{actual_type} for column: #{df_col}"
544
+ raise "Unexpected dtype: #{actual_type} for column: #{name}"
532
545
  end
533
546
  else
534
- Polars.col(df_col).cast(expected_dtype)
547
+ Polars.col(name).cast(expected_dtype, strict: false)
535
548
  end
536
549
 
537
- cast_statement.alias(df_col)
550
+ cast_statement.alias(name)
538
551
  end
539
552
 
540
553
  def cast(value)
@@ -101,10 +101,10 @@ module EasyML
101
101
  end
102
102
  cast_statements = (df.columns & schema.keys.map(&:to_s)).map do |df_col|
103
103
  db_col = column_index[df_col]
104
- expected_dtype = schema[df_col.to_sym]
105
- db_col.cast_statement(df, df_col, expected_dtype)
104
+ db_col.cast_statement(df[df_col])
106
105
  end
107
106
  df = df.with_columns(cast_statements)
107
+ df
108
108
  end
109
109
 
110
110
  def cast(processed_or_raw)
@@ -10,7 +10,9 @@ module EasyML
10
10
  end
11
11
 
12
12
  def unique_count
13
- Polars.col(column.name).n_unique.alias("#{column.name}__unique_count")
13
+ Polars.col(column.name)
14
+ .cast(column.polars_datatype)
15
+ .n_unique.alias("#{column.name}__unique_count")
14
16
  end
15
17
  end
16
18
  end
@@ -5,11 +5,30 @@ module EasyML
5
5
  class Numeric < Query
6
6
  def train_query
7
7
  super.concat([
8
- Polars.col(column.name).mean.alias("#{column.name}__mean"),
9
- Polars.col(column.name).median.alias("#{column.name}__median"),
10
- Polars.col(column.name).min.alias("#{column.name}__min"),
11
- Polars.col(column.name).max.alias("#{column.name}__max"),
12
- Polars.col(column.name).std.alias("#{column.name}__std"),
8
+ Polars.col(column.name)
9
+ .cast(column.polars_datatype)
10
+ .mean
11
+ .alias("#{column.name}__mean"),
12
+
13
+ Polars.col(column.name)
14
+ .cast(column.polars_datatype)
15
+ .median
16
+ .alias("#{column.name}__median"),
17
+
18
+ Polars.col(column.name)
19
+ .cast(column.polars_datatype)
20
+ .min
21
+ .alias("#{column.name}__min"),
22
+
23
+ Polars.col(column.name)
24
+ .cast(column.polars_datatype)
25
+ .max
26
+ .alias("#{column.name}__max"),
27
+
28
+ Polars.col(column.name)
29
+ .cast(column.polars_datatype)
30
+ .std
31
+ .alias("#{column.name}__std"),
13
32
  ])
14
33
  end
15
34
  end
@@ -44,25 +44,37 @@ module EasyML
44
44
  end
45
45
 
46
46
  def null_count
47
- Polars.col(column.name).null_count.alias("#{column.name}__null_count")
47
+ Polars.col(column.name)
48
+ .cast(column.polars_datatype)
49
+ .null_count
50
+ .alias("#{column.name}__null_count")
48
51
  end
49
52
 
50
53
  def num_rows
51
- Polars.col(column.name).len.alias("#{column.name}__num_rows")
54
+ Polars.col(column.name)
55
+ .cast(column.polars_datatype)
56
+ .len
57
+ .alias("#{column.name}__num_rows")
52
58
  end
53
59
 
54
60
  def most_frequent_value
55
- Polars.col(column.name).filter(Polars.col(column.name).is_not_null).mode.first.alias("#{column.name}__most_frequent_value")
61
+ Polars.col(column.name)
62
+ .cast(column.polars_datatype)
63
+ .filter(Polars.col(column.name).is_not_null)
64
+ .mode
65
+ .first
66
+ .alias("#{column.name}__most_frequent_value")
56
67
  end
57
68
 
58
69
  def last_value
59
70
  return unless dataset.date_column.present?
60
71
 
61
72
  Polars.col(column.name)
62
- .sort_by(dataset.date_column.name, reverse: true, nulls_last: true)
63
- .filter(Polars.col(column.name).is_not_null)
64
- .first
65
- .alias("#{column.name}__last_value")
73
+ .cast(column.polars_datatype)
74
+ .sort_by(dataset.date_column.name, reverse: true, nulls_last: true)
75
+ .filter(Polars.col(column.name).is_not_null)
76
+ .first
77
+ .alias("#{column.name}__last_value")
66
78
  end
67
79
  end
68
80
  end
@@ -10,7 +10,10 @@ module EasyML
10
10
  end
11
11
 
12
12
  def unique_count
13
- Polars.col(column.name).cast(:str).n_unique.alias("#{column.name}__unique_count")
13
+ Polars.col(column.name)
14
+ .cast(Polars::String)
15
+ .n_unique
16
+ .alias("#{column.name}__unique_count")
14
17
  end
15
18
  end
16
19
  end
@@ -22,9 +22,22 @@ module EasyML
22
22
  def run_queries(split, type)
23
23
  queries = build_queries(split, type)
24
24
 
25
- dataset.columns.apply_clip(
26
- @dataset.send(type).send(split, all_columns: true, lazy: true)
27
- ).select(queries).collect
25
+ begin
26
+ dataset.columns.apply_clip(
27
+ @dataset.send(type).send(split, all_columns: true, lazy: true)
28
+ )
29
+ .select(queries).collect
30
+ rescue => e
31
+ problematic_queries = queries.select { |query|
32
+ begin
33
+ dataset.send(type).send(split, all_columns: true, lazy: true).select([query]).collect
34
+ false
35
+ rescue => e
36
+ true
37
+ end
38
+ }
39
+ raise "Query failed for queries... likely due to wrong column datatype: #{problematic_queries.join("\n")}"
40
+ end
28
41
  end
29
42
 
30
43
  def get_column_statistics(query_results)
@@ -51,4 +64,4 @@ module EasyML
51
64
  end
52
65
  end
53
66
  end
54
- end
67
+ end
@@ -20,6 +20,7 @@
20
20
  # updated_at :datetime not null
21
21
  # last_datasource_sha :string
22
22
  # raw_schema :jsonb
23
+ # view_class :string
23
24
  #
24
25
  module EasyML
25
26
  class Dataset < ActiveRecord::Base
@@ -64,6 +65,7 @@ module EasyML
64
65
  reject_if: :all_blank
65
66
 
66
67
  validates :datasource, presence: true
68
+ validate :view_class_exists, if: -> { view_class.present? }
67
69
 
68
70
  add_configuration_attributes :remote_files
69
71
 
@@ -85,6 +87,10 @@ module EasyML
85
87
  feature_options: EasyML::Features::Registry.list_flat,
86
88
  splitter_constants: EasyML::Splitter.constants,
87
89
  embedding_constants: EasyML::Data::Embeddings::Embedder.constants,
90
+ available_views: Rails.root.join("app/datasets").glob("*.rb").map { |f|
91
+ name = f.basename(".rb").to_s.camelize
92
+ { value: name, label: name.titleize }
93
+ }
88
94
  }
89
95
  end
90
96
 
@@ -148,7 +154,7 @@ module EasyML
148
154
  return @schema if @schema
149
155
  return read_attribute(:schema) if @serializing
150
156
 
151
- schema = read_attribute(:schema) || datasource.schema || datasource.after_sync.schema
157
+ schema = read_attribute(:schema) || materialized_view&.schema || datasource.schema || datasource.after_sync.schema
152
158
  schema = set_schema(schema)
153
159
  @schema = EasyML::Data::PolarsSchema.deserialize(schema)
154
160
  end
@@ -157,7 +163,7 @@ module EasyML
157
163
  return @raw_schema if @raw_schema
158
164
  return read_attribute(:raw_schema) if @serializing
159
165
 
160
- raw_schema = read_attribute(:raw_schema) || datasource.schema || datasource.after_sync.schema
166
+ raw_schema = read_attribute(:raw_schema) || materialized_view&.schema || datasource.schema || datasource.after_sync.schema
161
167
  raw_schema = set_raw_schema(raw_schema)
162
168
  @raw_schema = EasyML::Data::PolarsSchema.deserialize(raw_schema)
163
169
  end
@@ -178,7 +184,12 @@ module EasyML
178
184
  if datasource&.num_rows.nil?
179
185
  datasource.after_sync
180
186
  end
181
- datasource&.num_rows
187
+
188
+ if materialized_view.present?
189
+ materialized_view.shape[0]
190
+ else
191
+ datasource&.num_rows
192
+ end
182
193
  end
183
194
 
184
195
  def abort!
@@ -234,6 +245,29 @@ module EasyML
234
245
  features.update_all(workflow_status: "ready")
235
246
  end
236
247
 
248
+ def view_class_exists
249
+ begin
250
+ view_class.constantize
251
+ rescue NameError
252
+ errors.add(:view_class, "must be a valid class name")
253
+ end
254
+ end
255
+
256
+ def materialize_view(df)
257
+ df
258
+ end
259
+
260
+ def materialized_view
261
+ return @materialized_view if @materialized_view
262
+
263
+ original_df = datasource.data
264
+ if view_class.present?
265
+ @materialized_view = view_class.constantize.new.materialize_view(original_df)
266
+ else
267
+ @materialized_view = materialize_view(original_df)
268
+ end
269
+ end
270
+
237
271
  def prepare!
238
272
  prepare_features
239
273
  cleanup
@@ -423,6 +457,7 @@ module EasyML
423
457
  end
424
458
 
425
459
  def needs_learn?
460
+ return true if view_class.present?
426
461
  return true if columns_need_refresh?
427
462
 
428
463
  never_learned = columns.none?
@@ -471,6 +506,7 @@ module EasyML
471
506
  def normalize(df = nil, split_ys: false, inference: false, all_columns: false, features: self.features)
472
507
  df = apply_missing_columns(df, inference: inference)
473
508
  df = transform_columns(df, inference: inference, encode: false)
509
+ df = apply_cast(df)
474
510
  df = apply_features(df, features, inference: inference)
475
511
  df = apply_cast(df) if inference
476
512
  df = transform_columns(df, inference: inference)
@@ -798,7 +834,8 @@ module EasyML
798
834
  df = df.clone
799
835
  df = apply_features(df)
800
836
  processed.save(:train, df)
801
- learn_statistics(type: :processed)
837
+ learn(delete: false)
838
+ learn_statistics(type: :processed, computed: true)
802
839
  processed.cleanup
803
840
  end
804
841
 
@@ -836,11 +873,12 @@ module EasyML
836
873
  return unless force || needs_refresh?
837
874
 
838
875
  cleanup
839
- splitter.split(datasource) do |train_df, valid_df, test_df|
840
- [:train, :valid, :test].zip([train_df, valid_df, test_df]).each do |segment, df|
841
- raw.save(segment, df)
842
- end
843
- end
876
+
877
+ train_df, valid_df, test_df = splitter.split(self)
878
+ raw.save(:train, train_df)
879
+ raw.save(:valid, valid_df)
880
+ raw.save(:test, test_df)
881
+
844
882
  raw_schema # Set if not already set
845
883
  end
846
884
 
@@ -25,6 +25,7 @@
25
25
  # snapshot_id :string
26
26
  # last_datasource_sha :string
27
27
  # raw_schema :jsonb
28
+ # view_class :string
28
29
  #
29
30
  module EasyML
30
31
  class DatasetHistory < ActiveRecord::Base
@@ -277,24 +277,16 @@ module EasyML
277
277
  feature.fit_batch(batch_args.merge!(batch_id: batch_id))
278
278
  rescue => e
279
279
  EasyML::Feature.transaction do
280
- return if dataset.reload.workflow_status == :failed
281
-
282
- feature.update(workflow_status: :failed)
283
- dataset.update(workflow_status: :failed)
284
- build_error_with_context(dataset, e, batch_id, feature)
280
+ if dataset.reload.workflow_status != :failed
281
+ feature.update(workflow_status: :failed)
282
+ dataset.update(workflow_status: :failed)
283
+ EasyML::Event.handle_error(dataset, e)
284
+ end
285
285
  end
286
286
  raise e
287
287
  end
288
288
  end
289
289
 
290
- def self.build_error_with_context(dataset, error, batch_id, feature)
291
- error = EasyML::Event.handle_error(dataset, error)
292
- batch = feature.build_batch(batch_id: batch_id)
293
-
294
- # Convert any dataframes in the context to serialized form
295
- error.create_context(context: batch)
296
- end
297
-
298
290
  def self.fit_feature_failed(dataset, e)
299
291
  dataset.update(workflow_status: :failed)
300
292
  EasyML::Event.handle_error(dataset, e)
@@ -31,12 +31,13 @@ module EasyML
31
31
  }
32
32
  existing_lineage = existing_lineage.map do |key, lineage|
33
33
  matching_lineage = @lineage.detect { |ll| ll[:key].to_sym == lineage.key.to_sym }
34
+ next unless matching_lineage.present?
34
35
 
35
36
  lineage&.assign_attributes(
36
37
  occurred_at: matching_lineage[:occurred_at],
37
38
  description: matching_lineage[:description],
38
39
  )
39
- end
40
+ end.compact
40
41
  missing_lineage.concat(existing_lineage)
41
42
  end
42
43
  end
@@ -40,6 +40,7 @@ module EasyML
40
40
  end
41
41
  end
42
42
 
43
+ # STOP CHECKING S3 IN BETWEEN ITERATIONS... FIND WHERE REFRESH IS GETTING CALLED
43
44
  def after_iteration(booster, epoch, history)
44
45
  return false unless wandb_enabled?
45
46
 
@@ -320,7 +320,10 @@ module EasyML
320
320
  raise "Cannot predict on nil — XGBoost" if xs.nil?
321
321
 
322
322
  begin
323
+ @predicting = true
323
324
  y_pred = yield(preprocess(xs))
325
+ @predicting = false
326
+ y_pred
324
327
  rescue StandardError => e
325
328
  raise e unless e.message.match?(/Number of columns does not match/)
326
329
 
@@ -495,12 +498,14 @@ module EasyML
495
498
  feature_cols -= [weights_col] if weights_col
496
499
 
497
500
  # Get features, labels and weights
498
- exploded = explode_embeddings(xs.select(feature_cols))
501
+ exploded = explode_embeddings(xs)
499
502
  feature_cols = exploded.columns
500
503
  features = lazy ? exploded.collect.to_numo : exploded.to_numo
501
504
 
502
- weights = weights_col ? (lazy ? xs.select(weights_col).collect.to_numo : xs.select(weights_col).to_numo) : nil
503
- weights = weights.flatten if weights
505
+ unless @predicting
506
+ weights = weights_col ? (lazy ? xs.select(weights_col).collect.to_numo : xs.select(weights_col).to_numo) : nil
507
+ weights = weights.flatten if weights
508
+ end
504
509
  if ys.present?
505
510
  ys = ys.is_a?(Array) ? Polars::Series.new(ys) : ys
506
511
  labels = lazy ? ys.collect.to_numo.flatten : ys.to_numo.flatten
@@ -28,7 +28,7 @@ module EasyML
28
28
 
29
29
  def prediction
30
30
  prediction_value["value"]
31
- end
31
+ e end
32
32
 
33
33
  def probabilities
34
34
  metadata["probabilities"]
@@ -6,18 +6,14 @@ module EasyML
6
6
 
7
7
  attr_reader :splitter
8
8
 
9
- def split(datasource, &block)
10
- datasource.in_batches do |df|
11
- split_df(df).tap do |splits|
12
- yield splits if block_given?
13
- end
14
- end
15
- end
16
-
17
9
  def split_df(df)
18
10
  df
19
11
  end
20
12
 
13
+ def split(dataset)
14
+ split_df(dataset.materialized_view)
15
+ end
16
+
21
17
  def initialize(splitter)
22
18
  @splitter = splitter
23
19
  end
@@ -41,9 +41,10 @@ module EasyML
41
41
 
42
42
  validation_date_start, test_date_start = splits
43
43
 
44
+ dtype = df[date_col].dtype
44
45
  test_df = Polars.concat(
45
46
  [
46
- df.filter(Polars.col(date_col) >= test_date_start),
47
+ df.filter(Polars.col(date_col).ge(Polars.lit(test_date_start).cast(dtype))),
47
48
  df.filter(Polars.col(date_col).is_null),
48
49
  ]
49
50
  )
@@ -15,13 +15,18 @@ module EasyML
15
15
  }
16
16
  end
17
17
 
18
- def split(datasource, &block)
18
+ def split(dataset, &block)
19
19
  validate!
20
20
 
21
- files = datasource.all_files
21
+ files = dataset.datasource.all_files
22
22
  train, valid, test = match_files(files)
23
23
 
24
- yield [reader.query(train), reader.query(valid), reader.query(test)]
24
+ values = [reader.query(train), reader.query(valid), reader.query(test)]
25
+ if block_given?
26
+ yield values
27
+ else
28
+ values
29
+ end
25
30
  end
26
31
 
27
32
  def match_files(files)