easy_ml 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/README.md +270 -0
- data/Rakefile +12 -0
- data/app/models/easy_ml/model.rb +59 -0
- data/app/models/easy_ml/models/xgboost.rb +9 -0
- data/app/models/easy_ml/models.rb +5 -0
- data/lib/easy_ml/core/model.rb +29 -0
- data/lib/easy_ml/core/model_core.rb +181 -0
- data/lib/easy_ml/core/model_evaluator.rb +137 -0
- data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
- data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
- data/lib/easy_ml/core/models/xgboost.rb +10 -0
- data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
- data/lib/easy_ml/core/models.rb +10 -0
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
- data/lib/easy_ml/core/tuner/adapters.rb +10 -0
- data/lib/easy_ml/core/tuner.rb +105 -0
- data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
- data/lib/easy_ml/core/uploaders.rb +7 -0
- data/lib/easy_ml/core.rb +9 -0
- data/lib/easy_ml/core_ext/pathname.rb +9 -0
- data/lib/easy_ml/core_ext.rb +5 -0
- data/lib/easy_ml/data/dataloader.rb +6 -0
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
- data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
- data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
- data/lib/easy_ml/data/dataset/splits.rb +11 -0
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
- data/lib/easy_ml/data/dataset/splitters.rb +9 -0
- data/lib/easy_ml/data/dataset.rb +430 -0
- data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
- data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
- data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
- data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
- data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
- data/lib/easy_ml/data/datasource.rb +33 -0
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
- data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
- data/lib/easy_ml/data/preprocessor.rb +238 -0
- data/lib/easy_ml/data/utils.rb +50 -0
- data/lib/easy_ml/data.rb +8 -0
- data/lib/easy_ml/deployment.rb +5 -0
- data/lib/easy_ml/engine.rb +26 -0
- data/lib/easy_ml/initializers/inflections.rb +4 -0
- data/lib/easy_ml/logging.rb +38 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
- data/lib/easy_ml/support/age.rb +27 -0
- data/lib/easy_ml/support/est.rb +1 -0
- data/lib/easy_ml/support/file_rotate.rb +23 -0
- data/lib/easy_ml/support/git_ignorable.rb +66 -0
- data/lib/easy_ml/support/synced_directory.rb +134 -0
- data/lib/easy_ml/support/utc.rb +1 -0
- data/lib/easy_ml/support.rb +10 -0
- data/lib/easy_ml/trainer.rb +92 -0
- data/lib/easy_ml/transforms.rb +29 -0
- data/lib/easy_ml/version.rb +5 -0
- data/lib/easy_ml.rb +23 -0
- metadata +353 -0
@@ -0,0 +1,403 @@
|
|
1
|
+
require "active_support/core_ext/hash/deep_transform_values"
|
2
|
+
require "numo/narray"
|
3
|
+
require "json"
|
4
|
+
|
5
|
+
module EasyML
  module Data
    class Preprocessor
      # Fits and applies one preprocessing strategy (imputation, clipping, categorical
      # bucketing, custom lambdas, ...) for a single column.
      #
      # Fitted statistics are kept in +statistics+ keyed by strategy, e.g.
      #   { mean: { value: 4.2, original_dtype: <Polars dtype> } }
      # When both +path+ and +attribute+ are present they are also persisted to
      # "<path>/statistics.json" as { attribute => { strategy => stats } }, so the
      # imputers of every column and strategy can share one file.
      class SimpleImputer
        attr_reader :statistics
        attr_accessor :path, :attribute, :strategy, :options

        # @param strategy [String, Symbol] :mean, :median, :ffill, :most_frequent, :categorical,
        #   :constant, :clip, :today, :one_hot or :custom
        # @param path [String, nil] directory holding statistics.json
        # @param attribute [String, Symbol, nil] column the statistics belong to
        # @param options [Hash, nil] strategy options; defaults are filled in by #apply_defaults
        # @yield optional configuration block, evaluated with instance_eval
        #
        # Saved statistics are loaded when path and attribute are both present; without
        # them the imputer keeps its statistics in memory only.
        def initialize(strategy: "mean", path: nil, attribute: nil, options: {}, &block)
          @strategy = strategy.to_sym
          @path = path
          @attribute = attribute
          @options = options || {}
          apply_defaults
          load
          @statistics ||= {}
          deep_symbolize_keys!
          return unless block_given?

          instance_eval(&block)
        end

        # Converts every key of @statistics (nested hashes included) to a symbol.
        def deep_symbolize_keys!
          @statistics = @statistics.deep_symbolize_keys
        end

        # Fills in strategy-specific option defaults without overriding caller values.
        def apply_defaults
          @options[:date_column] ||= "CREATED_DATE"

          case strategy
          when :categorical
            @options[:categorical_min] ||= 25
          when :custom
            identity = ->(col) { col }
            @options[:fit] ||= identity
            @options[:transform] ||= identity
          end
        end

        # Learns the statistics the configured strategy needs, records the input dtype
        # and persists the result (when persistence is configured).
        #
        # @param x [Polars::Series] column to fit on
        # @param df [Polars::DataFrame, nil] full frame; required by :ffill
        # @return [SimpleImputer] self
        # @raise [ArgumentError] if x is not a Polars::Series or the strategy is unknown
        def fit(x, df = nil)
          x = validate_input(x)

          fit_values = case @strategy
                       when :mean then fit_mean(x)
                       when :median then fit_median(x)
                       when :ffill then fit_ffill(x, df)
                       when :most_frequent then fit_most_frequent(x)
                       when :categorical then fit_categorical(x)
                       when :constant then fit_constant(x)
                       when :clip, :today, :one_hot then fit_no_op(x)
                       when :custom then fit_custom(x)
                       else raise ArgumentError, "Invalid strategy: #{@strategy}"
                       end || {}

          # Keyed by strategy, the same shape #load produces, so #transform works on
          # this instance immediately after fitting.
          @statistics[@strategy] = fit_values.merge(original_dtype: x.dtype)
          save
          self
        end

        # Applies the fitted strategy to a Polars::Series, or element-wise to any
        # other collection responding to #map.
        # @raise [RuntimeError] if the strategy needs statistics that were never fitted
        def transform(x)
          check_is_fitted

          if x.is_a?(Polars::Series)
            transform_polars(x)
          else
            transform_dense(x)
          end
        end

        # Polars-native transform; the result is cast back to the dtype recorded at fit time.
        def transform_polars(x)
          result = case @strategy
                   when :mean, :median, :ffill, :most_frequent, :constant
                     x.fill_null(@statistics[@strategy][:value])
                   when :clip
                     min = option_value(:min, 0)
                     max = option_value(:max, 1_000_000_000_000)
                     # An all-null series is returned unchanged rather than clipped.
                     x.null_count == x.len ? x : x.clip(min, max)
                   when :categorical
                     allowed_values = allowed_categorical_values
                     if x.null_count == x.len
                       x.fill_null(transform_categorical(nil))
                     else
                       x.apply { |val| allowed_values.include?(val) ? val : transform_categorical(val) }
                     end
                   when :today
                     x.fill_null(transform_today(nil))
                   when :custom
                     if x.null_count == x.len
                       x.fill_null(transform_custom(nil))
                     else
                       x.apply { |val| should_transform_custom?(val) ? transform_custom(val) : val }
                     end
                   else
                     raise ArgumentError, "Unsupported strategy for Polars::Series: #{@strategy}"
                   end

          original_dtype = @statistics.dig(@strategy, :original_dtype)
          original_dtype ? result.cast(original_dtype) : result
        end

        # Location of the shared statistics file.
        # @raise [RuntimeError] unless both attribute and path are present
        def file_path
          raise "Need both attribute and path to save/load statistics" unless persistable?

          File.join(path, "statistics.json")
        end

        # Forgets fitted statistics and deletes the statistics file (when persisted).
        def cleanup
          @statistics = {}
          return unless persistable?

          FileUtils.rm_f(file_path)
        end

        # Writes this strategy's statistics into the shared file, leaving the entries of
        # other attributes and of this attribute's other strategies untouched.
        # No-op when path or attribute is missing.
        def save
          return unless persistable?

          FileUtils.mkdir_p(File.dirname(file_path))
          # JSON keys stay strings here so existing entries are rewritten unchanged and
          # lookups by attribute/strategy name cannot miss because of String/Symbol mismatches.
          all_statistics = File.exist?(file_path) ? JSON.parse(File.read(file_path)) : {}

          deep_symbolize_keys!
          attribute_stats = (all_statistics[attribute.to_s] ||= {})
          attribute_stats[@strategy.to_s] = serialize_statistics(@statistics[@strategy] || {})

          File.write(file_path, JSON.pretty_generate(all_statistics))
        end

        # Loads this attribute's saved statistics (all strategies), if any exist.
        def load
          return unless persistable? && File.exist?(file_path)

          attribute_stats = JSON.parse(File.read(file_path))[attribute.to_s]
          return unless attribute_stats

          @statistics = deserialize_statistics(attribute_stats)
          deep_symbolize_keys!
        end

        # True when val is not one of the sufficiently frequent categories.
        def should_transform_categorical?(val)
          !allowed_categorical_values.include?(val)
        end

        # Returns val (as a string) when it is a frequent category, otherwise "other".
        def transform_categorical(val)
          return "other" if val.nil?

          allowed_categorical_values.include?(val.to_s) ? val.to_s : "other"
        end

        # Replacement value for the :today strategy.
        def transform_today(_val)
          EST.now.beginning_of_day
        end

        # :custom keeps no fitted statistics of its own (it is exempt from
        # #check_is_fitted); returning a Hash lets #fit record the input dtype.
        def fit_custom(_x)
          {}
        end

        # Uses options[:should_transform] when given, otherwise the nil/NaN check.
        def should_transform_custom?(x)
          if options.key?(:should_transform)
            options[:should_transform].call(x)
          else
            should_transform_default?(x)
          end
        end

        # Applies options[:transform] to a single value.
        def transform_custom(x)
          raise "Transform required" unless options.key?(:transform)

          options[:transform].call(x)
        end

        private

        # Persistence requires both a non-blank attribute and a non-blank path.
        def persistable?
          !attribute.to_s.strip.empty? && !path.to_s.strip.empty?
        end

        # Reads an option by symbol key, then by string key, then falls back to default.
        def option_value(key, default)
          options[key] || options[key.to_s] || default
        end

        # Categories (as strings) whose fitted count reaches options[:categorical_min].
        def allowed_categorical_values
          counts = @statistics.dig(:categorical, :value) || {}
          min_count = options[:categorical_min] || 25
          counts.select { |_value, count| count >= min_count }.keys.map(&:to_s)
        end

        def validate_input(x)
          raise ArgumentError, "Input must be a Polars::Series" unless x.is_a?(Polars::Series)

          x
        end

        def fit_mean(x)
          { value: x.mean }
        end

        def fit_median(x)
          { value: x.median }
        end

        # Remembers the last non-null value of x after sorting df by the date column.
        # Returns nil (no statistics) when the date column is entirely null.
        def fit_ffill(x, df)
          raise ArgumentError, "ffill strategy requires the source DataFrame" if df.nil?

          date_column = options[:date_column]
          return if df[date_column].is_null.all

          sorted_df = df.sort(date_column)
          column = sorted_df[x.name]
          {
            value: column.filter(column.is_not_null).tail(1).to_a.first,
            max_date: sorted_df[date_column].max
          }
        end

        # Most common non-null value; nil when the column has no non-null values.
        def fit_most_frequent(x)
          value_counts = x.filter(x.is_not_null).value_counts
          return { value: nil } if value_counts.height.zero?

          count_column = value_counts.columns[1]
          { value: value_counts.sort(count_column, descending: true).row(0)[0] }
        end

        # Counts per category plus a label encoder (sorted category => index) and its inverse.
        def fit_categorical(x)
          value_counts = x.value_counts
          value_column, count_column = value_counts.columns

          as_hash = value_counts.select([value_column, count_column]).rows.to_a.to_h.transform_keys(&:to_s)
          label_encoder = as_hash.keys.sort.each_with_index.to_h

          {
            value: as_hash,
            label_encoder: label_encoder,
            label_decoder: label_encoder.invert
          }
        end

        def fit_no_op(_x)
          {}
        end

        def fit_constant(_x)
          { value: @options[:fill_value] }
        end

        def transform_default(_val)
          @statistics[strategy][:value]
        end

        # Values needing replacement are NaN (for floats) or nil (everything else).
        def should_transform_default?(val)
          val.respond_to?(:nan?) ? val.nan? : val.nil?
        end

        # Element-wise transform for plain Ruby collections. Uses strategy-specific
        # transform_/should_transform_ methods when they exist, otherwise the defaults.
        def transform_dense(x)
          strategy_method = respond_to?("transform_#{strategy}") ? "transform_#{strategy}" : "transform_default"
          checker_method = if respond_to?("should_transform_#{strategy}?")
                             "should_transform_#{strategy}?"
                           else
                             "should_transform_default?"
                           end

          result = x.map { |val| send(checker_method, val) ? send(strategy_method, val) : val }

          # The dtype is stored under the strategy key (same lookup as transform_polars).
          original_dtype = @statistics.dig(strategy, :original_dtype)
          original_dtype ? result.map { |val| cast_to_dtype(val, original_dtype) } : result
        end

        # Strategies that need no fitted statistics are always considered fitted.
        def check_is_fitted
          return if %i[clip today custom].include?(strategy)

          raise "SimpleImputer has not been fitted yet for #{attribute}##{strategy}" unless @statistics[strategy]
        end

        # Wraps values JSON cannot round-trip in { "__type__" => ..., "value" => ... }.
        def serialize_statistics(stats)
          stats.deep_transform_values do |value|
            case value
            when Time, DateTime
              { "__type__" => "datetime", "value" => value.iso8601 }
            when Date
              { "__type__" => "date", "value" => value.iso8601 }
            when BigDecimal
              { "__type__" => "bigdecimal", "value" => value.to_s }
            when Polars::DataType
              { "__type__" => "polars_dtype", "value" => value.to_s }
            when Symbol
              { "__type__" => "symbol", "value" => value.to_s }
            else
              value
            end
          end
        end

        def deserialize_statistics(stats)
          stats.transform_values do |value|
            recursive_deserialize(value)
          end
        end

        # Inverse of serialize_statistics, applied through nested hashes and arrays.
        def recursive_deserialize(value)
          case value
          when Hash
            if value["__type__"]
              deserialize_special_type(value)
            else
              value.transform_values { |v| recursive_deserialize(v) }
            end
          when Array
            value.map { |v| recursive_deserialize(v) }
          else
            value
          end
        end

        def deserialize_special_type(value)
          case value["__type__"]
          when "datetime"
            DateTime.parse(value["value"])
          when "date"
            Date.parse(value["value"])
          when "bigdecimal"
            BigDecimal(value["value"])
          when "polars_dtype"
            parse_polars_dtype(value["value"])
          when "symbol"
            value["value"].to_sym
          else
            value["value"]
          end
        end

        # Turns a serialized dtype string such as "Polars::Int64" back into a Polars dtype class.
        def parse_polars_dtype(dtype_string)
          case dtype_string
          when /^Polars::Datetime/
            time_unit = dtype_string[/time_unit: "(.*?)"/, 1]
            time_zone = dtype_string[/time_zone: (.*)?\)/, 1]
            time_zone = time_zone == "nil" ? nil : time_zone&.delete('"')
            Polars::Datetime.new(time_unit: time_unit, time_zone: time_zone).class
          when /^Polars::/
            Polars.const_get(dtype_string.split("::").last)
          else
            raise ArgumentError, "Unknown Polars data type: #{dtype_string}"
          end
        end

        # Coerces a Ruby value to match a Polars dtype, given either as a dtype class
        # (as produced by parse_polars_dtype) or a dtype instance. nil is left as nil.
        def cast_to_dtype(value, dtype)
          return value if value.nil?

          dtype_class = dtype.is_a?(Class) ? dtype : dtype.class
          if dtype_class <= Polars::Int64
            value.to_i
          elsif dtype_class <= Polars::Float64
            value.to_f
          elsif dtype_class <= Polars::Boolean
            !!value
          elsif dtype_class <= Polars::Utf8
            value.to_s
          else
            value
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module EasyML::Data
  class Preprocessor
    # Configuration helpers mixed into Preprocessor.
    module Utils
      # Expands array shorthand in a column => strategies config, in place.
      # A column given as ["mean", "clip"] becomes { "mean" => true, "clip" => true };
      # columns whose strategies are already a Hash are left untouched.
      #
      # @param config [Hash] column => strategies (Array or Hash); mutated in place
      # @return [Hash] the same config object
      def standardize_config(config)
        array_columns = config.select { |_column, strategies| strategies.is_a?(Array) }.keys

        array_columns.each do |column|
          config[column] = config[column].to_h { |strategy| [strategy, true] }
        end

        config
      end
    end
  end
end
|
@@ -0,0 +1,238 @@
|
|
1
|
+
require "fileutils"
|
2
|
+
require "polars"
|
3
|
+
require "date"
|
4
|
+
require "json"
|
5
|
+
require_relative "preprocessor/utils"
|
6
|
+
require_relative "preprocessor/simple_imputer"
|
7
|
+
|
8
|
+
module EasyML::Data
  # Per-column preprocessing driven by a config of the form
  #   { training: { "col" => { "mean" => true, "clip" => { "min" => 0 } } }, inference: { ... } }
  # (array shorthand such as { "col" => ["mean", "clip"] } is expanded by
  # Utils#standardize_config). Each (column, strategy) pair is handled by a
  # SimpleImputer whose statistics are stored under +directory+.
  class Preprocessor
    include GlueGun::DSL
    include EasyML::Data::Preprocessor::Utils

    CATEGORICAL_COMMON_MIN = 50
    # Order in which a column's strategies run. Strategies missing from this list run
    # after the listed ones, in their configured order.
    PREPROCESSING_ORDER = %w[
      clip mean median most_frequent constant categorical one_hot ffill today custom fill_date add_datepart
    ].freeze

    attribute :directory, :string
    attribute :verbose, :boolean, default: false
    attribute :preprocessing_steps, :hash, default: {}

    # Stores the steps with array shorthand expanded and indifferent (string/symbol) key access.
    def preprocessing_steps=(preprocessing_steps)
      super(standardize_config(preprocessing_steps).with_indifferent_access)
    end

    # Fits every configured imputer (training plus inference steps) on df and saves
    # their statistics. "clip" is the only strategy applied during fitting, so
    # strategies ordered after it are fitted on clipped values.
    #
    # @param df [Polars::DataFrame, nil] training frame; nil is ignored
    def fit(df)
      return if df.nil?
      return if preprocessing_steps.keys.none?

      puts "Preprocessing..." if verbose
      imputers = initialize_imputers(steps_for(inference: true))

      did_cleanup = false
      imputers.each do |col, col_imputers|
        sorted_strategies(col_imputers).each do |strategy|
          imputer = col_imputers[strategy]
          # Delete previously saved statistics once, before the first new fit.
          unless did_cleanup
            imputer.cleanup
            did_cleanup = true
          end

          actual_col = find_column(df, col)
          if actual_col
            imputer.fit(df[actual_col], df)
            df[actual_col] = imputer.transform(df[actual_col]) if strategy == "clip"
          elsif verbose
            puts "Warning: Column '#{col}' not found in DataFrame during fit process."
          end
        end
      end
    end

    # Applies the fitted training steps (plus inference steps when inference: true).
    # @return [Polars::DataFrame] the transformed frame
    def postprocess(df, inference: false)
      puts "Postprocessing..." if verbose
      return df if preprocessing_steps.keys.none?

      df = apply_transformations(df, steps_for(inference: inference))

      puts "Postprocessing complete." if verbose
      df
    end

    # @return [Hash] column => { strategy => statistics loaded for that imputer }
    def statistics
      initialize_imputers(preprocessing_steps[:training] || {}).transform_values do |strategies|
        strategies.transform_values(&:statistics)
      end
    end

    # True when any imputer has non-empty saved statistics.
    def is_fit?
      statistics.any? { |_, col_stats| col_stats.any? { |_, strategy_stats| strategy_stats.present? } }
    end

    # Removes the statistics directory, if it exists.
    def delete
      return unless directory && File.directory?(directory)

      FileUtils.rm_rf(directory)
    end

    # Replaces the last path segment of +directory+ with +to+ (".../test" -> ".../prod"),
    # moves the directory on disk and points this preprocessor at the new location.
    def move(to)
      old_dir = directory
      segments = old_dir.split("/")
      segments[-1] = to
      new_dir = segments.join("/")

      puts "Moving #{old_dir} to #{new_dir}"
      FileUtils.mv(old_dir, new_dir)
      self.directory = new_dir
    end

    # Maps encoded labels back to category names; the code one past the largest
    # known label decodes to "other".
    #
    # @param values [Enumerable] encoded labels
    # @param col [String, Symbol] column whose "categorical" step produced the labels
    # @raise [ArgumentError] if the column has no fitted label decoder
    def decode_labels(values, col: nil)
      imputers = initialize_imputers(preprocessing_steps[:training] || {}, dumb: true)
      imputer = imputers.dig(col.to_s, "categorical")
      decoder = imputer&.statistics&.dig(:categorical, :label_decoder)
      raise ArgumentError, "No fitted label decoder for column #{col.inspect}" unless decoder

      decoder = decoder.transform_keys(&:to_s)
      other_value = (decoder.keys.map(&:to_i).max || -1) + 1
      decoder[other_value.to_s] = "other"

      values.map { |value| decoder[value.to_s] }
    end

    private

    # Training steps, optionally combined with inference-only steps. Neither stored
    # config is mutated.
    def steps_for(inference:)
      training = preprocessing_steps[:training] || {}
      inference ? training.merge(preprocessing_steps[:inference] || {}) : training
    end

    # The frame's column name matching col case-insensitively, or nil.
    def find_column(df, col)
      df.columns.find { |c| c.downcase == col.to_s.downcase }
    end

    # Builds { column => { strategy => SimpleImputer } }; a strategy configured as
    # `true` gets empty options.
    def initialize_imputers(config, dumb: false)
      standardize_config(config).each_with_object({}) do |(col, strategies), hash|
        hash[col] ||= {}
        strategies.each do |strategy, options|
          options = {} if options == true

          hash[col][strategy] = EasyML::Data::Preprocessor::SimpleImputer.new(
            strategy: strategy,
            path: directory,
            attribute: col,
            options: options
          )
        end
      end
    end

    # Runs each configured column's strategies in PREPROCESSING_ORDER. "categorical"
    # only acts when its options request one_hot or encode_labels.
    def apply_transformations(df, config)
      imputers = initialize_imputers(config)

      standardize_config(config).each do |col, strategies|
        actual_col = find_column(df, col)
        unless actual_col
          puts "Warning: Column '#{col}' not found in DataFrame during apply_transformations process." if verbose
          next
        end

        sorted_strategies(strategies).each do |strategy|
          imputer = imputers.dig(col, strategy)
          next unless imputer

          if strategy.to_sym == :categorical
            if imputer.options.dig("one_hot")
              df = apply_one_hot(df, col, imputers, column: actual_col)
            elsif imputer.options.dig("encode_labels")
              df = apply_encode_labels(df, col, imputers, column: actual_col)
            end
          else
            df[actual_col] = imputer.transform(df[actual_col])
          end
        end
      end

      df
    end

    # Fitted category values counted at least options[:categorical_min] times.
    def approved_categorical_values(imputer)
      min_count = imputer.options[:categorical_min] || imputer.options["categorical_min"]
      counts = imputer.statistics.dig(:categorical, :value) || {}
      counts.select { |_value, count| count >= min_count }.keys
    end

    # Replaces the column with one 0/1 Int64 column per approved value plus
    # "<col>_other". `column` is the frame's actual (case-matched) column name.
    def apply_one_hot(df, col, imputers, column: col)
      approved_values = approved_categorical_values(imputers.dig(col, "categorical"))
      approved_strings = approved_values.map(&:to_s)

      approved_values.each do |value|
        new_col_name = "#{col}_#{value}".tr("-", "_")
        df = df.with_column(
          df[column].eq(value.to_s).cast(Polars::Int64).alias(new_col_name)
        )
      end

      df["#{col}_other"] = df[column].map_elements { |value| !approved_strings.include?(value) }.cast(Polars::Int64)
      df.drop([column])
    end

    # Buckets non-approved values into "other", then replaces each value with its
    # fitted label code ("other" gets the next unused code).
    def apply_encode_labels(df, col, imputers, column: col)
      cat_imputer = imputers.dig(col, "categorical")
      approved_strings = approved_categorical_values(cat_imputer).map(&:to_s)

      # with_column returns a new frame, so the bucketed result must be kept.
      df = df.with_column(
        df[column].map_elements { |value| approved_strings.include?(value) ? value : "other" }.alias(column)
      )

      label_encoder = cat_imputer.statistics[:categorical][:label_encoder].transform_keys(&:to_s)
      label_encoder["other"] = (label_encoder.values.max || -1) + 1

      df.with_column(
        df[column].map_elements { |value| label_encoder[value.to_s] }.alias(column)
      )
    end

    # Strategy names sorted by PREPROCESSING_ORDER; unlisted names keep their
    # relative order and run last.
    def sorted_strategies(strategies)
      fallback = PREPROCESSING_ORDER.length
      strategies.keys.each_with_index.sort_by do |key, position|
        [PREPROCESSING_ORDER.index(key.to_s) || fallback, position]
      end.map(&:first)
    end

    def prepare_for_imputation(df, col)
      df = df.with_column(Polars.col(col).cast(Polars::Float64))
      df.with_column(Polars.when(Polars.col(col).is_null).then(Float::NAN).otherwise(Polars.col(col)).alias(col))
    end
  end
end
|
211
|
+
|
212
|
+
# TODO: decide where these git-staging helpers belong; they are commented out and currently unused.
|
213
|
+
#
|
214
|
+
# def self.stage_required_files
|
215
|
+
# required_files.each do |file|
|
216
|
+
# git_add(file)
|
217
|
+
# end
|
218
|
+
# end
|
219
|
+
|
220
|
+
# def self.git_add(path)
|
221
|
+
# command = "git add #{path}"
|
222
|
+
# puts command if verbose
|
223
|
+
# result = `#{command}`
|
224
|
+
# puts result if verbose
|
225
|
+
# end
|
226
|
+
|
227
|
+
# def self.set_verbose(verbose)
|
228
|
+
# @verbose = verbose
|
229
|
+
# end
|
230
|
+
|
231
|
+
# def required_files
|
232
|
+
# files = Dir.entries(@directory) - %w[. ..]
|
233
|
+
# required_file_types = %w[bin]
|
234
|
+
|
235
|
+
# files.select { |file| required_file_types.any? { |ext| file.include?(ext) } }.map do |file|
|
236
|
+
# File.join(@directory, file)
|
237
|
+
# end
|
238
|
+
# end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
module EasyML
  module Data
    # File and DataFrame helpers shared by the data classes.
    module Utils
      # Appends df to the CSV at path. The header row is written only when the target
      # file is empty (first write).
      #
      # The frame is first written to "<path>.tmp" and then copied line by line into
      # path. The temporary file is removed even when writing or copying raises.
      #
      # @param df [#empty?, #write_csv] frame to append; empty frames are skipped
      # @param path [String, Pathname] destination CSV; parent directories are created
      # @return [void]
      def append_to_csv(df, path)
        return if df.empty?

        path = Pathname.new(path) if path.is_a?(String)
        FileUtils.mkdir_p(path.dirname)
        FileUtils.touch(path)

        # A zero-length file means nothing has been written yet, so keep the header.
        first_write = File.zero?(path)
        temp_file = "#{path}.tmp"

        begin
          df.write_csv(temp_file)
          File.open(path, "a") do |out|
            File.foreach(temp_file).with_index do |line, index|
              next if index.zero? && !first_write

              out.write(line)
            end
          end
        ensure
          File.delete(temp_file) if File.exist?(temp_file)
        end
      end

      # Absolute paths are returned as given; relative ones are resolved against Rails.root.
      def expand_dir(dir)
        return dir if dir.to_s.start_with?("/")

        Rails.root.join(dir)
      end

      # Reports columns containing nulls.
      #
      # @return [Hash, nil] column => { null_count:, total_count: } for columns with at
      #   least one null, or nil when no column has nulls
      def null_check(df)
        null_counts = df.null_count
        total_count = df.height

        result = df.columns.each_with_object({}) do |column, report|
          null_count = null_counts[column][0]
          next if null_count == 0

          report[column] = { null_count: null_count, total_count: total_count }
        end

        result.empty? ? nil : result
      end
    end
  end
end
|