easy_ml 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +270 -0
- data/Rakefile +12 -0
- data/app/models/easy_ml/model.rb +59 -0
- data/app/models/easy_ml/models/xgboost.rb +9 -0
- data/app/models/easy_ml/models.rb +5 -0
- data/lib/easy_ml/core/model.rb +29 -0
- data/lib/easy_ml/core/model_core.rb +181 -0
- data/lib/easy_ml/core/model_evaluator.rb +137 -0
- data/lib/easy_ml/core/models/hyperparameters/base.rb +34 -0
- data/lib/easy_ml/core/models/hyperparameters/xgboost.rb +19 -0
- data/lib/easy_ml/core/models/hyperparameters.rb +8 -0
- data/lib/easy_ml/core/models/xgboost.rb +10 -0
- data/lib/easy_ml/core/models/xgboost_core.rb +220 -0
- data/lib/easy_ml/core/models.rb +10 -0
- data/lib/easy_ml/core/tuner/adapters/base_adapter.rb +63 -0
- data/lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb +50 -0
- data/lib/easy_ml/core/tuner/adapters.rb +10 -0
- data/lib/easy_ml/core/tuner.rb +105 -0
- data/lib/easy_ml/core/uploaders/model_uploader.rb +24 -0
- data/lib/easy_ml/core/uploaders.rb +7 -0
- data/lib/easy_ml/core.rb +9 -0
- data/lib/easy_ml/core_ext/pathname.rb +9 -0
- data/lib/easy_ml/core_ext.rb +5 -0
- data/lib/easy_ml/data/dataloader.rb +6 -0
- data/lib/easy_ml/data/dataset/data/preprocessor/statistics.json +31 -0
- data/lib/easy_ml/data/dataset/data/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/dataset/files/sample_info.json +1 -0
- data/lib/easy_ml/data/dataset/splits/file_split.rb +140 -0
- data/lib/easy_ml/data/dataset/splits/in_memory_split.rb +49 -0
- data/lib/easy_ml/data/dataset/splits/split.rb +98 -0
- data/lib/easy_ml/data/dataset/splits.rb +11 -0
- data/lib/easy_ml/data/dataset/splitters/date_splitter.rb +43 -0
- data/lib/easy_ml/data/dataset/splitters.rb +9 -0
- data/lib/easy_ml/data/dataset.rb +430 -0
- data/lib/easy_ml/data/datasource/datasource_factory.rb +60 -0
- data/lib/easy_ml/data/datasource/file_datasource.rb +40 -0
- data/lib/easy_ml/data/datasource/merged_datasource.rb +64 -0
- data/lib/easy_ml/data/datasource/polars_datasource.rb +41 -0
- data/lib/easy_ml/data/datasource/s3_datasource.rb +89 -0
- data/lib/easy_ml/data/datasource.rb +33 -0
- data/lib/easy_ml/data/preprocessor/preprocessor.rb +205 -0
- data/lib/easy_ml/data/preprocessor/simple_imputer.rb +403 -0
- data/lib/easy_ml/data/preprocessor/utils.rb +17 -0
- data/lib/easy_ml/data/preprocessor.rb +238 -0
- data/lib/easy_ml/data/utils.rb +50 -0
- data/lib/easy_ml/data.rb +8 -0
- data/lib/easy_ml/deployment.rb +5 -0
- data/lib/easy_ml/engine.rb +26 -0
- data/lib/easy_ml/initializers/inflections.rb +4 -0
- data/lib/easy_ml/logging.rb +38 -0
- data/lib/easy_ml/railtie/generators/migration/migration_generator.rb +42 -0
- data/lib/easy_ml/railtie/templates/migration/create_easy_ml_models.rb.tt +23 -0
- data/lib/easy_ml/support/age.rb +27 -0
- data/lib/easy_ml/support/est.rb +1 -0
- data/lib/easy_ml/support/file_rotate.rb +23 -0
- data/lib/easy_ml/support/git_ignorable.rb +66 -0
- data/lib/easy_ml/support/synced_directory.rb +134 -0
- data/lib/easy_ml/support/utc.rb +1 -0
- data/lib/easy_ml/support.rb +10 -0
- data/lib/easy_ml/trainer.rb +92 -0
- data/lib/easy_ml/transforms.rb +29 -0
- data/lib/easy_ml/version.rb +5 -0
- data/lib/easy_ml.rb +23 -0
- metadata +353 -0

data/lib/easy_ml/data/preprocessor/simple_imputer.rb

@@ -0,0 +1,403 @@
+require "active_support/core_ext/hash/deep_transform_values"
+require "numo/narray"
+require "json"
+
+module EasyML
+  module Data
+    class Preprocessor
+      class SimpleImputer
+        attr_reader :statistics
+        attr_accessor :path, :attribute, :strategy, :options
+
+        def initialize(strategy: "mean", path: nil, attribute: nil, options: {}, &block)
+          @strategy = strategy.to_sym
+          @path = path
+          @attribute = attribute
+          @options = options || {}
+          apply_defaults
+          load
+          @statistics ||= {}
+          deep_symbolize_keys!
+          return unless block_given?
+
+          instance_eval(&block)
+        end
+
+        def deep_symbolize_keys!
+          @statistics = @statistics.deep_symbolize_keys
+        end
+
+        def apply_defaults
+          @options[:date_column] ||= "CREATED_DATE"
+
+          if strategy == :categorical
+            @options[:categorical_min] ||= 25
+          elsif strategy == :custom
+            itself = ->(col) { col }
+            @options[:fit] ||= itself
+            @options[:transform] ||= itself
+          end
+        end
+
+        def fit(x, df = nil)
+          x = validate_input(x)
+
+          fit_values = case @strategy
+                       when :mean
+                         fit_mean(x)
+                       when :median
+                         fit_median(x)
+                       when :ffill
+                         fit_ffill(x, df)
+                       when :most_frequent
+                         fit_most_frequent(x)
+                       when :categorical
+                         fit_categorical(x)
+                       when :constant
+                         fit_constant(x)
+                       when :clip
+                         fit_no_op(x)
+                       when :today
+                         fit_no_op(x)
+                       when :one_hot
+                         fit_no_op(x)
+                       when :custom
+                         fit_custom(x)
+                       else
+                         raise ArgumentError, "Invalid strategy: #{@strategy}"
+                       end || {}
+
+          @statistics[attribute] ||= {}
+          @statistics[attribute][@strategy] = fit_values.merge!(original_dtype: x.dtype)
+          save
+          self
+        end
+
+        def transform(x)
+          check_is_fitted
+
+          if x.is_a?(Polars::Series)
+            transform_polars(x)
+          else
+            transform_dense(x)
+          end
+        end
+
+        def transform_polars(x)
+          result = case @strategy
+                   when :mean, :median, :ffill, :most_frequent, :constant
+                     x.fill_null(@statistics[@strategy][:value])
+                   when :clip
+                     min = options["min"] || 0
+                     max = options["max"] || 1_000_000_000_000
+                     if x.null_count != x.len
+                       x.clip(min, max)
+                     else
+                       x
+                     end
+                   when :categorical
+                     allowed_values = @statistics.dig(:categorical, :value).select do |_k, v|
+                       v >= options[:categorical_min]
+                     end.keys.map(&:to_s)
+                     if x.null_count == x.len
+                       x.fill_null(transform_categorical(nil))
+                     else
+                       x.apply do |val|
+                         allowed_values.include?(val) ? val : transform_categorical(val)
+                       end
+                     end
+                   when :today
+                     x.fill_null(transform_today(nil))
+                   when :custom
+                     if x.null_count == x.len
+                       x.fill_null(transform_custom(nil))
+                     else
+                       x.apply do |val|
+                         should_transform_custom?(val) ? transform_custom(val) : val
+                       end
+                     end
+                   else
+                     raise ArgumentError, "Unsupported strategy for Polars::Series: #{@strategy}"
+                   end
+
+          # Cast the result back to the original dtype
+          original_dtype = @statistics.dig(@strategy, :original_dtype)
+          original_dtype ? result.cast(original_dtype) : result
+        end
+
+        def file_path
+          raise "Need both attribute and path to save/load statistics" unless attribute.present? && path.to_s.present?
+
+          File.join(path, "statistics.json")
+        end
+
+        def cleanup
+          @statistics = {}
+          FileUtils.rm(file_path) if File.exist?(file_path)
+        end
+
+        def save
+          FileUtils.mkdir_p(File.dirname(file_path))
+
+          all_statistics = (File.exist?(file_path) ? JSON.parse(File.read(file_path)) : {}).deep_symbolize_keys
+
+          deep_symbolize_keys!
+
+          serialized = serialize_statistics(@statistics)
+          all_statistics[attribute] = {} unless all_statistics.key?(attribute)
+          all_statistics[attribute][@strategy] = serialized[attribute.to_sym][@strategy.to_sym]
+
+          File.open(file_path, "w") do |file|
+            file.write(JSON.pretty_generate(all_statistics))
+          end
+        end
+
+        def load
+          return unless File.exist?(file_path)
+
+          all_statistics = JSON.parse(File.read(file_path))
+          attribute_stats = all_statistics[@attribute]
+
+          return unless attribute_stats
+
+          @statistics = deserialize_statistics(attribute_stats)
+          deep_symbolize_keys!
+        end
+
+        def should_transform_categorical?(val)
+          values = @statistics.dig(:categorical, :value) || {}
+          min_ct = options[:categorical_min] || 25
+          allowed_values = values.select { |_v, c| c >= min_ct }
+
+          allowed_values.keys.map(&:to_s).exclude?(val)
+        end
+
+        def transform_categorical(val)
+          return "other" if val.nil?
+
+          values = @statistics.dig(:categorical, :value) || {}
+          min_ct = options[:categorical_min] || 25
+          allowed_values = values.select { |_v, c| c >= min_ct }.keys.map(&:to_s)
+
+          binding.pry
+          allowed_values.include?(val.to_s) ? val.to_s : "other"
+        end
+
+        def transform_today(_val)
+          EST.now.beginning_of_day
+        end
+
+        def fit_custom(x)
+          x
+        end
+
+        def should_transform_custom?(x)
+          if options.key?(:should_transform)
+            options[:should_transform].call(x)
+          else
+            should_transform_default?(x)
+          end
+        end
+
+        def transform_custom(x)
+          raise "Transform required" unless options.key?(:transform)
+
+          options[:transform].call(x)
+        end
+
+        private
+
+        def validate_input(x)
+          raise ArgumentError, "Input must be a Polars::Series" unless x.is_a?(Polars::Series)
+
+          x
+        end
+
+        def fit_mean(x)
+          { value: x.mean }
+        end
+
+        def fit_median(x)
+          { value: x.median }
+        end
+
+        def fit_ffill(x, df)
+          values = { value: nil, max_date: nil }
+
+          date_col = df[options[:date_column]]
+          return if date_col.is_null.all
+
+          sorted_df = df.sort(options[:date_column])
+          new_max_date = sorted_df[options[:date_column]].max
+
+          current_max_date = values[:max_date]
+          return if current_max_date && current_max_date > new_max_date
+
+          values[:max_date] = [current_max_date, new_max_date].compact.max
+
+          # Get the last non-null value
+          last_non_null = sorted_df[x.name].filter(sorted_df[x.name].is_not_null).tail(1).to_a.first
+          values[:value] = last_non_null
+
+          values
+        end
+
+        def fit_most_frequent(x)
+          value_counts = x.filter(x.is_not_null).value_counts
+          column_names = value_counts.columns
+          column_names[0]
+          count_column = column_names[1]
+
+          most_frequent_value = value_counts.sort(count_column, descending: true).row(0)[0]
+          { value: most_frequent_value }
+        end
+
+        def fit_categorical(x)
+          value_counts = x.value_counts
+          column_names = value_counts.columns
+          value_column = column_names[0]
+          count_column = column_names[1]
+
+          as_hash = value_counts.select([value_column, count_column]).rows.to_a.to_h.transform_keys(&:to_s)
+          label_encoder = as_hash.keys.sort.each.with_index.reduce({}) do |h, (k, i)|
+            h.tap do
+              h[k] = i
+            end
+          end
+          label_decoder = label_encoder.invert
+
+          {
+            value: as_hash,
+            label_encoder: label_encoder,
+            label_decoder: label_decoder
+          }
+        end
+
+        def fit_no_op(_x)
+          {}
+        end
+
+        def fit_constant(_x)
+          { value: @options[:fill_value] }
+        end
+
+        def transform_default(_val)
+          @statistics[strategy][:value]
+        end
+
+        def should_transform_default?(val)
+          checker_method = val.respond_to?(:nan?) ? :nan? : :nil?
+          val.send(checker_method)
+        end
+
+        def transform_dense(x)
+          result = x.map do |val|
+            strategy_method = respond_to?("transform_#{strategy}") ? "transform_#{strategy}" : "transform_default"
+            checker_method = respond_to?("should_transform_#{strategy}?") ? "should_transform_#{strategy}?" : "should_transform_default?"
+            send(checker_method, val) ? send(strategy_method, val) : val
+          end
+
+          # Cast the result back to the original dtype
+          original_dtype = @statistics[:original_dtype]
+          if original_dtype
+            result.map { |val| cast_to_dtype(val, original_dtype) }
+          else
+            result
+          end
+        end
+
+        def check_is_fitted
+          return if %i[clip today custom].include?(strategy)
+
+          raise "SimpleImputer has not been fitted yet for #{attribute}##{strategy}" unless @statistics[strategy]
+        end
+
+        def serialize_statistics(stats)
+          stats.deep_transform_values do |value|
+            case value
+            when Time, DateTime
+              { "__type__" => "datetime", "value" => value.iso8601 }
+            when Date
+              { "__type__" => "date", "value" => value.iso8601 }
+            when BigDecimal
+              { "__type__" => "bigdecimal", "value" => value.to_s }
+            when Polars::DataType
+              { "__type__" => "polars_dtype", "value" => value.to_s }
+            when Symbol
+              { "__type__" => "symbol", "value" => value.to_s }
+            else
+              value
+            end
+          end
+        end
+
+        def deserialize_statistics(stats)
+          stats.transform_values do |value|
+            recursive_deserialize(value)
+          end
+        end
+
+        def recursive_deserialize(value)
+          case value
+          when Hash
+            if value["__type__"]
+              deserialize_special_type(value)
+            else
+              value.transform_values { |v| recursive_deserialize(v) }
+            end
+          when Array
+            value.map { |v| recursive_deserialize(v) }
+          else
+            value
+          end
+        end
+
+        def deserialize_special_type(value)
+          case value["__type__"]
+          when "datetime"
+            DateTime.parse(value["value"])
+          when "date"
+            Date.parse(value["value"])
+          when "bigdecimal"
+            BigDecimal(value["value"])
+          when "polars_dtype"
+            parse_polars_dtype(value["value"])
+          when "symbol"
+            value["value"].to_sym
+          else
+            value["value"]
+          end
+        end
+
+        def parse_polars_dtype(dtype_string)
+          case dtype_string
+          when /^Polars::Datetime/
+            time_unit = dtype_string[/time_unit: "(.*?)"/, 1]
+            time_zone = dtype_string[/time_zone: (.*)?\)/, 1]
+            time_zone = time_zone == "nil" ? nil : time_zone&.delete('"')
+            Polars::Datetime.new(time_unit: time_unit, time_zone: time_zone).class
+          when /^Polars::/
+            Polars.const_get(dtype_string.split("::").last)
+          else
+            raise ArgumentError, "Unknown Polars data type: #{dtype_string}"
+          end
+        end
+
+        def cast_to_dtype(value, dtype)
+          case dtype
+          when Polars::Int64
+            value.to_i
+          when Polars::Float64
+            value.to_f
+          when Polars::Boolean
+            !!value
+          when Polars::Utf8
+            value.to_s
+          else
+            value
+          end
+        end
+      end
+    end
+  end
+end

data/lib/easy_ml/data/preprocessor/utils.rb

@@ -0,0 +1,17 @@
+module EasyML::Data
+  class Preprocessor
+    module Utils
+      def standardize_config(config)
+        config.each do |column, strategies|
+          next unless strategies.is_a?(Array)
+
+          config[column] = strategies.reduce({}) do |hash, strategy|
+            hash.tap do
+              hash[strategy] = true
+            end
+          end
+        end
+      end
+    end
+  end
+end
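
As a quick illustration of `standardize_config`, an array of strategy names is normalized into a strategy-to-true hash. The host class and column name below are hypothetical, used only to mix the module in:

```ruby
require "easy_ml" # assumed to load EasyML::Data::Preprocessor::Utils

class ConfigNormalizer
  include EasyML::Data::Preprocessor::Utils
end

config = { "annual_revenue" => %w[clip median] }
ConfigNormalizer.new.standardize_config(config)
# => { "annual_revenue" => { "clip" => true, "median" => true } }
```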

data/lib/easy_ml/data/preprocessor.rb

@@ -0,0 +1,238 @@
+require "fileutils"
+require "polars"
+require "date"
+require "json"
+require_relative "preprocessor/utils"
+require_relative "preprocessor/simple_imputer"
+
+module EasyML::Data
+  class Preprocessor
+    include GlueGun::DSL
+    include EasyML::Data::Preprocessor::Utils
+
+    CATEGORICAL_COMMON_MIN = 50
+    PREPROCESSING_ORDER = %w[clip mean median constant categorical one_hot ffill custom fill_date add_datepart]
+
+    attribute :directory, :string
+    attribute :verbose, :boolean, default: false
+    attribute :preprocessing_steps, :hash, default: {}
+    def preprocessing_steps=(preprocessing_steps)
+      super(standardize_config(preprocessing_steps).with_indifferent_access)
+    end
+
+    def fit(df)
+      return if df.nil?
+      return if preprocessing_steps.keys.none?
+
+      puts "Preprocessing..." if verbose
+      imputers = initialize_imputers(
+        preprocessing_steps[:training].merge!(preprocessing_steps[:inference] || {})
+      )
+
+      did_cleanup = false
+      imputers.each do |col, imputers|
+        sorted_strategies(imputers).each do |strategy|
+          imputer = imputers[strategy]
+          unless did_cleanup
+            imputer.cleanup
+            did_cleanup = true
+          end
+          if df.columns.map(&:downcase).include?(col.downcase)
+            actual_col = df.columns.find { |c| c.downcase == imputer.attribute.downcase }
+            imputer.fit(df[actual_col], df)
+            if strategy == "clip" # This is the only one to transform during fit
+              df[actual_col] = imputer.transform(df[actual_col])
+            end
+          elsif @verbose
+            puts "Warning: Column '#{col}' not found in DataFrame during fit process."
+          end
+        end
+      end
+    end
+
+    def postprocess(df, inference: false)
+      puts "Postprocessing..." if verbose
+      return df if preprocessing_steps.keys.none?
+
+      steps = if inference
+                preprocessing_steps[:training].merge(preprocessing_steps[:inference] || {})
+              else
+                preprocessing_steps[:training]
+              end
+
+      df = apply_transformations(df, steps)
+
+      puts "Postprocessing complete." if @verbose
+      df
+    end
+
+    def statistics
+      initialize_imputers(preprocessing_steps[:training]).each_with_object({}) do |(col, strategies), result|
+        result[col] = strategies.each_with_object({}) do |(strategy, imputer), col_result|
+          col_result[strategy] = imputer.statistics
+        end
+      end
+    end
+
+    def is_fit?
+      statistics.any? { |_, col_stats| col_stats.any? { |_, strategy_stats| strategy_stats.present? } }
+    end
+
+    def delete
+      return unless File.directory?(@directory)
+
+      FileUtils.rm_rf(@directory)
+    end
+
+    def move(to)
+      old_dir = directory
+      current_env = directory.split("/")[-1]
+      new_dir = directory.gsub(Regexp.new(current_env), to)
+
+      puts "Moving #{old_dir} to #{new_dir}"
+      FileUtils.mv(old_dir, new_dir)
+      @directory = new_dir
+    end
+
+    def decode_labels(values, col: nil)
+      imputers = initialize_imputers(preprocessing_steps[:training], dumb: true)
+      imputer = imputers.dig(col, "categorical")
+      decoder = imputer.statistics.dig(:categorical, :label_decoder)
+
+      other_value = decoder.keys.map(&:to_s).map(&:to_i).max + 1
+      decoder[other_value] = "other"
+      decoder.stringify_keys!
+
+      values.map do |value|
+        decoder[value.to_s]
+      end
+    end
+
+    private
+
+    def initialize_imputers(config, dumb: false)
+      standardize_config(config).each_with_object({}) do |(col, strategies), hash|
+        hash[col] ||= {}
+        strategies.each do |strategy, options|
+          options = {} if options == true
+
+          hash[col][strategy] = EasyML::Data::Preprocessor::SimpleImputer.new(
+            strategy: strategy,
+            path: directory,
+            attribute: col,
+            options: options
+          )
+        end
+      end
+    end
+
+    def apply_transformations(df, config)
+      imputers = initialize_imputers(config)
+
+      standardize_config(config).each do |col, strategies|
+        if df.columns.map(&:downcase).include?(col.downcase)
+          actual_col = df.columns.find { |c| c.downcase == col.downcase }
+
+          sorted_strategies(strategies).each do |strategy|
+            if strategy.to_sym == :categorical
+              if imputers.dig(col, strategy).options.dig("one_hot")
+                df = apply_one_hot(df, col, imputers)
+              elsif imputers.dig(col, strategy).options.dig("encode_labels")
+                df = apply_encode_labels(df, col, imputers)
+              end
+            else
+              imputer = imputers.dig(col, strategy)
+              df[actual_col] = imputer.transform(df[actual_col]) if imputer
+            end
+          end
+        elsif @verbose
+          puts "Warning: Column '#{col}' not found in DataFrame during apply_transformations process."
+        end
+      end
+
+      df
+    end
+
+    def apply_one_hot(df, col, imputers)
+      cat_imputer = imputers.dig(col, "categorical")
+      approved_values = cat_imputer.statistics[:categorical][:value].select do |_k, v|
+        v >= cat_imputer.options["categorical_min"]
+      end.keys
+
+      # Create one-hot encoded columns
+      approved_values.each do |value|
+        new_col_name = "#{col}_#{value}".tr("-", "_")
+        df = df.with_column(
+          df[col].eq(value.to_s).cast(Polars::Int64).alias(new_col_name)
+        )
+      end
+
+      # Create 'other' column for unapproved values
+      other_col_name = "#{col}_other"
+      df[other_col_name] = df[col].map_elements do |value|
+        approved_values.map(&:to_s).exclude?(value)
+      end.cast(Polars::Int64)
+      df.drop([col])
+    end
+
+    def apply_encode_labels(df, col, imputers)
+      cat_imputer = imputers.dig(col, "categorical")
+      approved_values = cat_imputer.statistics[:categorical][:value].select do |_k, v|
+        v >= cat_imputer.options["categorical_min"]
+      end.keys
+
+      df.with_column(
+        df[col].map_elements do |value|
+          approved_values.map(&:to_s).exclude?(value) ? "other" : value
+        end.alias(col)
+      )
+
+      label_encoder = cat_imputer.statistics[:categorical][:label_encoder].stringify_keys
+      other_value = label_encoder.values.max + 1
+      label_encoder["other"] = other_value
+
+      df.with_column(
+        df[col].map { |v| label_encoder[v.to_s] }.alias(col)
+      )
+    end
+
+    def sorted_strategies(strategies)
+      strategies.keys.sort_by do |key|
+        PREPROCESSING_ORDER.index(key)
+      end
+    end
+
+    def prepare_for_imputation(df, col)
+      df = df.with_column(Polars.col(col).cast(Polars::Float64))
+      df.with_column(Polars.when(Polars.col(col).is_null).then(Float::NAN).otherwise(Polars.col(col)).alias(col))
+    end
+  end
+end
+
+# Where to put this???
+#
+# def self.stage_required_files
+#   required_files.each do |file|
+#     git_add(file)
+#   end
+# end
+
+# def self.git_add(path)
+#   command = "git add #{path}"
+#   puts command if verbose
+#   result = `#{command}`
+#   puts result if verbose
+# end
+
+# def self.set_verbose(verbose)
+#   @verbose = verbose
+# end
+
+# def required_files
+#   files = Dir.entries(@directory) - %w[. ..]
+#   required_file_types = %w[bin]
+
+#   files.select { |file| required_file_types.any? { |ext| file.include?(ext) } }.map do |file|
+#     File.join(@directory, file)
+#   end
+# end
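
A minimal sketch of driving the Preprocessor above, with an illustrative DataFrame and configuration; it assumes GlueGun::DSL allows keyword-argument construction of the declared attributes and that `require "easy_ml"` loads the data classes:

```ruby
require "tmpdir"
require "polars"
require "easy_ml" # assumed to load EasyML::Data::Preprocessor

# Column lookup in the preprocessor is case-insensitive (downcase comparison).
df = Polars::DataFrame.new({ "ANNUAL_REVENUE" => [100.0, nil, 250_000_000.0] })

# Keyword-argument construction of the GlueGun::DSL attributes is assumed here.
preprocessor = EasyML::Data::Preprocessor.new(
  directory: Dir.mktmpdir,
  verbose: true,
  preprocessing_steps: {
    training: {
      annual_revenue: {
        clip: { "min" => 0, "max" => 1_000_000 }, # clip is also applied during fit
        median: true
      }
    }
  }
)

preprocessor.fit(df)              # fits one SimpleImputer per column/strategy and persists statistics.json
df = preprocessor.postprocess(df) # applies the fitted training steps: clip, then the median fill
preprocessor.statistics           # per-column, per-strategy statistics reloaded from disk
```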

data/lib/easy_ml/data/utils.rb

@@ -0,0 +1,50 @@
+module EasyML
+  module Data
+    module Utils
+      def append_to_csv(df, path)
+        return if df.empty?
+
+        path = Pathname.new(path) if path.is_a?(String)
+        FileUtils.mkdir_p(path.dirname)
+        FileUtils.touch(path)
+
+        # Check if the file is empty (i.e., if this is the first write)
+        file_empty = File.zero?(path)
+
+        # Write the DataFrame to a temporary file
+        temp_file = "#{path}.tmp"
+        df.write_csv(temp_file)
+
+        # Append the content to the main file, skipping the header if not the first write
+        File.open(path, "a") do |f|
+          File.foreach(temp_file).with_index do |line, index|
+            # Skip the header line if the file is not empty
+            f.write(line) unless index == 0 && !file_empty
+          end
+        end
+
+        # Delete the temporary file
+        File.delete(temp_file)
+      end
+
+      def expand_dir(dir)
+        return dir if dir.to_s[0] == "/"
+
+        Rails.root.join(dir)
+      end
+
+      def null_check(df)
+        result = {}
+        null_counts = df.null_count
+        total_count = df.height
+        df.columns.each do |column|
+          null_count = null_counts[column][0]
+          next if null_count == 0
+
+          result[column] = { null_count: null_count, total_count: total_count }
+        end
+        result.empty? ? nil : result
+      end
+    end
+  end
+end
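
EasyML::Data::Utils is a mixin, so a minimal sketch of using it looks like the following; the host class, DataFrame contents, and CSV path are hypothetical:

```ruby
require "fileutils"
require "pathname"
require "polars"
require "easy_ml" # assumed to load EasyML::Data::Utils

# Hypothetical host class; the module is meant to be mixed in.
class CsvSink
  include EasyML::Data::Utils
end

sink = CsvSink.new
df = Polars::DataFrame.new({ "id" => [1, 2, 3], "amount" => [10.0, nil, 30.0] })

p sink.null_check(df)                     # lists columns containing nulls with null/total counts
sink.append_to_csv(df, "tmp/batches.csv") # writes the header on the first call, appends rows after that
```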