easy_ml 0.2.0.pre.rc52 → 0.2.0.pre.rc55

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b0b6da194500895ff325b1408757ee328b355da448a5fcc53439d7b14ccea9c8
4
- data.tar.gz: da221df414245aafd56c5a05367146b9aae131885a99fa19e11aaa0f3cd32a20
3
+ metadata.gz: 53b0666a5ba25b758573564fb916da9ec5b968223e0b8164e394c6b2d176668a
4
+ data.tar.gz: 5180eb72880c443fd5405bba8127ab2ba9c891f51c7a7063f47862a48d10ca12
5
5
  SHA512:
6
- metadata.gz: 1d04422529423e09cff72496a29ff41038fdde5adfc0f267065981d04c24c9e9475d9ad0b2ab083e2d5e8cef22581180df97c93e960579cca8657b1771adc977
7
- data.tar.gz: d9904acd55317dc68d774c26f392399990fdd52f45f394f8fdc6addf11456ea1361e2d972a438647c7dc0f679e3642cd8ae1f13cc9e0824c72865de8ecb0b6a9
6
+ metadata.gz: 912c3eba2b15b1d463eb53a2e7cd689fec14c45cd58fb232a5f31114ea35f883004616977598d85b8b2e22bbd0ac19dc4e5f294878cc544007eaed1ffd61bd0d
7
+ data.tar.gz: b4b89abaf966fc0dc8f99e58feb80503175a97c8006ccd9d0bd17319ebdf1b1d1c615c762f159a7ea2b167abd9f0351e0cf2e4e3a7bc34694067f011a03aae0d
@@ -52,7 +52,7 @@ module EasyML
52
52
 
53
53
  flash_messages << { type: "success", message: flash[:notice] } if flash[:notice]
54
54
 
55
- flash_messages << { type: "error", message: flash[:alert] } if flash[:alert]
55
+ flash_messages << { type: "error", message: flash[:error] } if flash[:error]
56
56
 
57
57
  flash_messages << { type: "info", message: flash[:info] } if flash[:info]
58
58
 
@@ -13,7 +13,7 @@ module EasyML
13
13
  flash[:notice] = "Model deployment has started"
14
14
  redirect_to easy_ml_model_path(@deploy.model)
15
15
  rescue => e
16
- flash[:alert] = "Trouble deploying model: #{e.message}"
16
+ flash[:error] = "Trouble deploying model: #{e.message}"
17
17
  redirect_to easy_ml_model_path(@deploy.model)
18
18
  end
19
19
  end
@@ -70,6 +70,9 @@ module EasyML
70
70
  flash[:notice] = "Model was successfully updated."
71
71
  redirect_to easy_ml_models_path
72
72
  else
73
+ errors = model.errors.to_hash(true)
74
+ values = errors.values.flatten
75
+ flash.now[:error] = values.join(", ")
73
76
  render inertia: "pages/EditModelPage", props: {
74
77
  model: model_to_json(model),
75
78
  datasets: EasyML::Dataset.all.map { |dataset| dataset_to_json(dataset) },
@@ -99,7 +102,7 @@ module EasyML
99
102
  flash[:notice] = "Model was successfully deleted."
100
103
  redirect_to easy_ml_models_path
101
104
  else
102
- flash[:alert] = "Failed to delete the model."
105
+ flash[:error] = "Failed to delete the model."
103
106
  redirect_to easy_ml_models_path
104
107
  end
105
108
  end
@@ -38,9 +38,12 @@ export function AlertProvider({ children }: { children: React.ReactNode }) {
38
38
  const id = Math.random().toString(36).substring(7);
39
39
  setAlerts(prev => [...prev, { id, type, message }]);
40
40
 
41
- setTimeout(() => {
42
- removeAlert(id);
43
- }, numSeconds * 1000);
41
+ // Only auto-dismiss non-error alerts
42
+ if (type !== 'error') {
43
+ setTimeout(() => {
44
+ removeAlert(id);
45
+ }, numSeconds * 1000);
46
+ }
44
47
  }, [removeAlert]);
45
48
 
46
49
  return (
@@ -103,29 +103,6 @@ export function ModelForm({ initialData, datasets, constants, isEditing, errors:
103
103
  const objectives: { value: string; label: string; description?: string }[] =
104
104
  constants.objectives[data.model.model_type]?.[data.model.task] || [];
105
105
 
106
- useEffect(() => {
107
- // Only set default metrics if none were provided from the backend
108
- if (!initialData?.metrics) {
109
- const availableMetrics = constants.metrics[data.model.task]?.map(metric => metric.value) || [];
110
- setData({
111
- ...data,
112
- model: {
113
- ...data.model,
114
- objective: data.model.task === 'classification' ? 'binary:logistic' : 'reg:squarederror',
115
- metrics: availableMetrics
116
- }
117
- });
118
- } else {
119
- setData({
120
- ...data,
121
- model: {
122
- ...data.model,
123
- objective: data.model.task === 'classification' ? 'binary:logistic' : 'reg:squarederror'
124
- }
125
- });
126
- }
127
- }, [data.model.task]);
128
-
129
106
  useEffect(() => {
130
107
  if (isDataSet) {
131
108
  save();
@@ -187,11 +164,21 @@ export function ModelForm({ initialData, datasets, constants, isEditing, errors:
187
164
  save();
188
165
  };
189
166
 
190
- console.log(data.model)
191
167
  const selectedDataset = datasets.find(d => d.id === data.model.dataset_id);
192
168
 
193
169
  const filteredTunerJobConstants = constants.tuner_job_constants[data.model.model_type] || {};
194
170
 
171
+ const handleTaskChange = (value: string) => {
172
+ // First update the task
173
+ setData('model.task', value);
174
+
175
+ // Then force reset metrics to empty array
176
+ setData('model.metrics', []);
177
+
178
+ // Update objective based on new task
179
+ setData('model.objective', value === 'classification' ? 'binary:logistic' : 'reg:squarederror');
180
+ };
181
+
195
182
  return (
196
183
  <form onSubmit={handleSubmit} className="space-y-8">
197
184
  <div className="flex justify-between items-center border-b pb-4">
@@ -266,8 +253,7 @@ export function ModelForm({ initialData, datasets, constants, isEditing, errors:
266
253
  <SearchableSelect
267
254
  options={constants.tasks}
268
255
  value={data.model.task}
269
- onChange={(value) => setData('model.task', value as string)}
270
- placeholder="Select task"
256
+ onChange={handleTaskChange}
271
257
  />
272
258
  <ErrorDisplay error={errors.task} />
273
259
  </div>
@@ -300,24 +286,21 @@ export function ModelForm({ initialData, datasets, constants, isEditing, errors:
300
286
  type="checkbox"
301
287
  checked={data.model.metrics.includes(metric.value)}
302
288
  onChange={(e) => {
303
- const metrics = e.target.checked
289
+ const newMetrics = e.target.checked
304
290
  ? [...data.model.metrics, metric.value]
305
291
  : data.model.metrics.filter(m => m !== metric.value);
306
- setData('model.metrics', metrics);
292
+ setData('model.metrics', newMetrics);
307
293
  }}
308
- className="h-4 w-4 rounded border-gray-300 text-blue-600 focus:ring-blue-500"
294
+ className="h-4 w-4 text-blue-600 focus:ring-blue-500 border-gray-300 rounded"
309
295
  />
310
296
  <div className="ml-3">
311
- <span className="block text-sm font-medium text-gray-900">
312
- {metric.label}
313
- </span>
314
- <span className="block text-xs text-gray-500">
315
- {metric.direction === 'maximize' ? 'Higher is better' : 'Lower is better'}
316
- </span>
297
+ <span className="block text-sm font-medium text-gray-900">{metric.label}</span>
298
+ <span className="block text-xs text-gray-500">Direction: {metric.direction}</span>
317
299
  </div>
318
300
  </label>
319
301
  ))}
320
302
  </div>
303
+ <ErrorDisplay error={errors.metrics} />
321
304
  </div>
322
305
  </div>
323
306
 
@@ -374,6 +357,7 @@ export function ModelForm({ initialData, datasets, constants, isEditing, errors:
374
357
  dataset: selectedDataset,
375
358
  retraining_job: data.model.retraining_job_attributes
376
359
  }}
360
+ metrics={constants.metrics}
377
361
  tunerJobConstants={filteredTunerJobConstants}
378
362
  timezone={constants.timezone}
379
363
  retrainingJobConstants={constants.retraining_job_constants}
@@ -37,27 +37,20 @@ interface ScheduleModalProps {
37
37
  tuning_enabled?: boolean;
38
38
  };
39
39
  };
40
+ metrics: {
41
+ [key: string]: Array<{
42
+ value: string;
43
+ label: string;
44
+ description: string;
45
+ direction: string;
46
+ }>;
47
+ };
40
48
  tunerJobConstants: any;
41
49
  timezone: string;
42
50
  retrainingJobConstants: any;
43
51
  }
44
52
 
45
- const METRICS = {
46
- classification: [
47
- { value: 'accuracy_score', label: 'Accuracy', description: 'Overall prediction accuracy', direction: 'maximize' },
48
- { value: 'precision_score', label: 'Precision', description: 'Ratio of true positives to predicted positives', direction: 'maximize' },
49
- { value: 'recall_score', label: 'Recall', description: 'Ratio of true positives to actual positives', direction: 'maximize' },
50
- { value: 'f1_score', label: 'F1 Score', description: 'Harmonic mean of precision and recall', direction: 'maximize' }
51
- ],
52
- regression: [
53
- { value: 'mean_absolute_error', label: 'Mean Absolute Error', description: 'Average absolute differences between predicted and actual values', direction: 'minimize' },
54
- { value: 'mean_squared_error', label: 'Mean Squared Error', description: 'Average squared differences between predicted and actual values', direction: 'minimize' },
55
- { value: 'root_mean_squared_error', label: 'Root Mean Squared Error', description: 'Square root of mean squared error', direction: 'minimize' },
56
- { value: 'r2_score', label: 'R² Score', description: 'Proportion of variance in the target that is predictable', direction: 'maximize' }
57
- ]
58
- };
59
-
60
- export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobConstants, timezone, retrainingJobConstants }: ScheduleModalProps) {
53
+ export function ScheduleModal({ isOpen, onClose, onSave, initialData, metrics, tunerJobConstants, timezone, retrainingJobConstants }: ScheduleModalProps) {
61
54
  const [showBatchTrainingInfo, setShowBatchTrainingInfo] = useState(false);
62
55
  const [activeBatchPopover, setActiveBatchPopover] = useState<'size' | 'overlap' | null>(null);
63
56
 
@@ -97,7 +90,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
97
90
  day_of_week: initialData.retraining_job?.at?.day_of_week ?? 1,
98
91
  day_of_month: initialData.retraining_job?.at?.day_of_month ?? 1
99
92
  },
100
- metric: initialData.retraining_job?.metric || METRICS[initialData.task === 'classification' ? 'classification' : 'regression'][0].value,
93
+ metric: initialData.retraining_job?.metric || (metrics[initialData.task]?.[0]?.value ?? ''),
101
94
  threshold: initialData.retraining_job?.threshold || (initialData.task === 'classification' ? 0.85 : 0.1),
102
95
  tuner_config: initialData.retraining_job?.tuner_config ? {
103
96
  n_trials: initialData.retraining_job.tuner_config.n_trials || 10,
@@ -336,9 +329,9 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
336
329
  };
337
330
 
338
331
  return (
339
- <div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50">
340
- <div className="bg-white rounded-lg w-full max-w-6xl max-h-[90vh] overflow-hidden">
341
- <div className="flex justify-between items-center p-4 border-b">
332
+ <div className="fixed inset-0 bg-black bg-opacity-50 flex items-start justify-center pt-[5vh] z-50">
333
+ <div className="bg-white rounded-lg w-full max-w-6xl flex flex-col" style={{ maxHeight: '90vh' }}>
334
+ <div className="flex-none flex justify-between items-center p-4 border-b">
342
335
  <h2 className="text-lg font-semibold">Training Configuration</h2>
343
336
  <button
344
337
  onClick={onClose}
@@ -348,7 +341,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
348
341
  </button>
349
342
  </div>
350
343
 
351
- <div className="p-6 grid grid-cols-2 gap-8 max-h-[calc(90vh-8rem)] overflow-y-auto">
344
+ <div className="flex-1 p-6 grid grid-cols-2 gap-8 overflow-y-auto">
352
345
  {/* Left Column */}
353
346
  <div className="space-y-8">
354
347
  {/* Training Schedule */}
@@ -575,7 +568,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
575
568
  Metric
576
569
  </label>
577
570
  <SearchableSelect
578
- options={METRICS[initialData.task === 'classification' ? 'classification' : 'regression'].map((metric) => ({
571
+ options={metrics[initialData.task].map((metric) => ({
579
572
  value: metric.value,
580
573
  label: metric.label,
581
574
  description: metric.description
@@ -610,8 +603,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
610
603
  <h3 className="text-sm font-medium text-blue-800">Deployment Criteria</h3>
611
604
  <p className="mt-2 text-sm text-blue-700">
612
605
  {(() => {
613
- const metricsList = METRICS[initialData.task === 'classification' ? 'classification' : 'regression'];
614
- const selectedMetric = metricsList.find(m => m.value === formData.retraining_job_attributes.metric);
606
+ const selectedMetric = metrics[initialData.task].find(m => m.value === formData.retraining_job_attributes.metric);
615
607
  const direction = selectedMetric?.direction === 'minimize' ? 'below' : 'above';
616
608
 
617
609
  return `The model will be automatically deployed when the ${selectedMetric?.label} is ${direction} ${formData.retraining_job_attributes.threshold}.`;
@@ -711,7 +703,7 @@ export function ScheduleModal({ isOpen, onClose, onSave, initialData, tunerJobCo
711
703
  </div>
712
704
  </div>
713
705
 
714
- <div className="flex justify-end gap-4 p-4 border-t">
706
+ <div className="flex-none flex justify-end gap-4 p-4 border-t bg-white">
715
707
  <button
716
708
  onClick={onClose}
717
709
  className="px-4 py-2 text-sm font-medium text-gray-700 hover:text-gray-500"
@@ -1,5 +1,6 @@
1
1
  import React, { useState, useRef, useEffect, forwardRef } from 'react';
2
2
  import { Search, Check } from 'lucide-react';
3
+ import { createPortal } from 'react-dom';
3
4
 
4
5
  interface Option {
5
6
  value: string | number;
@@ -20,6 +21,7 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
20
21
  ({ options, value, onChange, placeholder = 'Search...', renderOption }, ref) => {
21
22
  const [isOpen, setIsOpen] = useState(false);
22
23
  const [searchQuery, setSearchQuery] = useState('');
24
+ const [dropdownPosition, setDropdownPosition] = useState({ top: 0, left: 0, width: 0 });
23
25
  const containerRef = useRef<HTMLDivElement>(null);
24
26
  const inputRef = useRef<HTMLInputElement>(null);
25
27
 
@@ -47,11 +49,100 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
47
49
  }
48
50
  }, [isOpen]);
49
51
 
52
+ useEffect(() => {
53
+ if (isOpen && containerRef.current) {
54
+ const rect = containerRef.current.getBoundingClientRect();
55
+ setDropdownPosition({
56
+ top: rect.bottom + window.scrollY,
57
+ left: rect.left + window.scrollX,
58
+ width: rect.width
59
+ });
60
+ }
61
+ }, [isOpen]);
62
+
63
+ const handleOptionClick = (optionValue: Option['value']) => {
64
+ onChange(optionValue);
65
+ setIsOpen(false);
66
+ setSearchQuery('');
67
+ };
68
+
69
+ const dropdown = isOpen && createPortal(
70
+ <div
71
+ className="fixed bg-white shadow-lg rounded-md overflow-hidden border border-gray-200"
72
+ style={{
73
+ top: dropdownPosition.top,
74
+ left: dropdownPosition.left,
75
+ width: dropdownPosition.width,
76
+ zIndex: 9999
77
+ }}
78
+ >
79
+ <div className="p-2 border-b">
80
+ <div className="relative">
81
+ <Search className="absolute left-3 top-1/2 transform -translate-y-1/2 w-4 h-4 text-gray-400" />
82
+ <input
83
+ ref={inputRef}
84
+ type="text"
85
+ className="w-full pl-9 pr-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-1 focus:ring-blue-500 focus:border-blue-500"
86
+ placeholder="Search..."
87
+ value={searchQuery}
88
+ onChange={(e) => setSearchQuery(e.target.value)}
89
+ onClick={(e) => e.stopPropagation()}
90
+ />
91
+ </div>
92
+ </div>
93
+
94
+ <div className="max-h-60 overflow-y-auto">
95
+ {filteredOptions.length === 0 ? (
96
+ <div className="text-center py-4 text-sm text-gray-500">
97
+ No results found
98
+ </div>
99
+ ) : (
100
+ <ul className="py-1">
101
+ {filteredOptions.map((option) => (
102
+ <li key={option.value}>
103
+ <button
104
+ type="button"
105
+ className={`w-full text-left px-4 py-2 hover:bg-gray-100 ${
106
+ option.value === value ? 'bg-blue-50' : ''
107
+ }`}
108
+ onMouseDown={(e) => {
109
+ e.preventDefault();
110
+ e.stopPropagation();
111
+ handleOptionClick(option.value);
112
+ }}
113
+ >
114
+ <div className="flex items-center justify-between">
115
+ <span className="block font-medium">
116
+ {option.label}
117
+ </span>
118
+ {option.value === value && (
119
+ <Check className="w-4 h-4 text-blue-600" />
120
+ )}
121
+ </div>
122
+ {option.description && (
123
+ <span className="block text-sm text-gray-500">
124
+ {option.description}
125
+ </span>
126
+ )}
127
+ </button>
128
+ </li>
129
+ ))}
130
+ </ul>
131
+ )}
132
+ </div>
133
+ </div>,
134
+ document.body
135
+ );
136
+
50
137
  return (
51
138
  <div className="relative" ref={containerRef}>
52
139
  <button
53
140
  type="button"
54
- onClick={() => setIsOpen(!isOpen)}
141
+ onMouseDown={(e) => {
142
+ e.preventDefault();
143
+ e.stopPropagation();
144
+ setIsOpen(!isOpen);
145
+ }}
55
146
  className="w-full bg-white relative border border-gray-300 rounded-md shadow-sm pl-3 pr-10 py-2 text-left cursor-pointer focus:outline-none focus:ring-1 focus:ring-blue-500 focus:border-blue-500"
56
147
  ref={ref}
57
148
  >
@@ -61,72 +152,8 @@ export const SearchableSelect = forwardRef<HTMLButtonElement, SearchableSelectPr
61
152
  <span className="block truncate text-gray-500">{placeholder}</span>
62
153
  )}
63
154
  </button>
64
-
65
- {isOpen && (
66
- <div className="absolute z-10 mt-1 w-full bg-white shadow-lg max-h-96 rounded-md overflow-hidden">
67
- <div className="p-2 border-b">
68
- <div className="relative">
69
- <Search className="absolute left-3 top-1/2 transform -translate-y-1/2 w-4 h-4 text-gray-400" />
70
- <input
71
- ref={inputRef}
72
- type="text"
73
- className="w-full pl-9 pr-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-1 focus:ring-blue-500 focus:border-blue-500"
74
- placeholder="Search..."
75
- value={searchQuery}
76
- onChange={(e) => setSearchQuery(e.target.value)}
77
- onClick={(e) => e.stopPropagation()}
78
- />
79
- </div>
80
- </div>
81
-
82
- <div className="max-h-60 overflow-y-auto">
83
- {filteredOptions.length === 0 ? (
84
- <div className="text-center py-4 text-sm text-gray-500">
85
- No results found
86
- </div>
87
- ) : (
88
- <ul className="py-1">
89
- {filteredOptions.map((option) => (
90
- <li key={option.value}>
91
- <button
92
- type="button"
93
- className={`w-full text-left px-4 py-2 hover:bg-gray-100 ${
94
- option.value === value ? 'bg-blue-50' : ''
95
- }`}
96
- onClick={() => {
97
- onChange(option.value);
98
- setIsOpen(false);
99
- setSearchQuery('');
100
- }}
101
- >
102
- {renderOption ? (
103
- renderOption(option)
104
- ) : (
105
- <div className="flex items-center justify-between">
106
- <div>
107
- <div className="font-medium">{option.label}</div>
108
- {option.description && (
109
- <div className="text-sm text-gray-500">
110
- {option.description}
111
- </div>
112
- )}
113
- </div>
114
- {option.value === value && (
115
- <Check className="w-4 h-4 text-blue-600" />
116
- )}
117
- </div>
118
- )}
119
- </button>
120
- </li>
121
- ))}
122
- </ul>
123
- )}
124
- </div>
125
- </div>
126
- )}
155
+ {dropdown}
127
156
  </div>
128
157
  );
129
158
  }
130
- );
131
-
132
- SearchableSelect.displayName = 'SearchableSelect';
159
+ );
@@ -32,6 +32,7 @@ module EasyML
32
32
  before_save :ensure_valid_datatype
33
33
  after_create :set_date_column_if_date_splitter
34
34
  after_save :handle_date_column_change
35
+ before_save :set_defaults
35
36
 
36
37
  # Scopes
37
38
  scope :visible, -> { where(hidden: false) }
@@ -98,6 +99,84 @@ module EasyML
98
99
 
99
100
  private
100
101
 
102
+ def set_defaults
103
+ self.preprocessing_steps = set_preprocessing_steps_defaults
104
+ end
105
+
106
+ def set_preprocessing_steps_defaults
107
+ preprocessing_steps.inject({}) do |h, (type, config)|
108
+ h.tap do
109
+ h[type] = set_preprocessing_step_defaults(config)
110
+ end
111
+ end
112
+ end
113
+
114
+ ALLOWED_PARAMS = {
115
+ constant: [:constant],
116
+ categorical: %i[categorical_min one_hot ordinal_encoding],
117
+ most_frequent: %i[one_hot ordinal_encoding],
118
+ mean: [:clip],
119
+ median: [:clip],
120
+ }
121
+
122
+ REQUIRED_PARAMS = {
123
+ constant: [:constant],
124
+ categorical: %i[categorical_min one_hot ordinal_encoding],
125
+ }
126
+
127
+ DEFAULT_PARAMS = {
128
+ categorical_min: 1,
129
+ one_hot: true,
130
+ ordinal_encoding: false,
131
+ clip: { min: 0, max: 1_000_000_000 },
132
+ constant: nil,
133
+ }
134
+
135
+ XOR_PARAMS = [{
136
+ params: [:one_hot, :ordinal_encoding],
137
+ default: :one_hot,
138
+ }]
139
+
140
+ def set_preprocessing_step_defaults(config)
141
+ config.deep_symbolize_keys!
142
+ config[:params] ||= {}
143
+ params = config[:params].symbolize_keys
144
+
145
+ required = REQUIRED_PARAMS.fetch(config[:method].to_sym, [])
146
+ allowed = ALLOWED_PARAMS.fetch(config[:method].to_sym, [])
147
+
148
+ missing = required - params.keys
149
+ missing.reject! do |param|
150
+ XOR_PARAMS.any? do |rule|
151
+ if rule[:params].include?(param)
152
+ missing_param = rule[:params].find { |p| p != param }
153
+ params[missing_param] == true
154
+ else
155
+ false
156
+ end
157
+ end
158
+ end
159
+ extra = params.keys - allowed
160
+
161
+ missing.each do |key|
162
+ params[key] = DEFAULT_PARAMS.fetch(key)
163
+ end
164
+
165
+ extra.each do |key|
166
+ params.delete(key)
167
+ end
168
+
169
+ # Only set one of one_hot or ordinal_encoding to true,
170
+ # by default set one_hot to true
171
+ xor = XOR_PARAMS.find { |rule| rule[:params] & params.keys == rule[:params] }
172
+ if xor && xor[:params].all? { |param| params[param] }
173
+ xor[:params].each { |param| params[param] = false }
174
+ params[xor[:default]] = true
175
+ end
176
+
177
+ config.merge!(params: params)
178
+ end
179
+
101
180
  def handle_date_column_change
102
181
  return unless saved_change_to_is_date_column? && is_date_column?
103
182
 
@@ -0,0 +1,41 @@
1
+ module EasyML
2
+ module Evaluators
3
+ class << self
4
+ def register_all
5
+ Dir.glob(Rails.root.join("app/evaluators/**/*.rb")).each { |f| require f }
6
+
7
+ ObjectSpace.each_object(Class).select { |klass|
8
+ klass < EasyML::Evaluators::Base
9
+ }.each do |evaluator_class|
10
+ register_evaluator(evaluator_class)
11
+ end
12
+ end
13
+
14
+ private
15
+
16
+ def register_evaluator(evaluator_class)
17
+ # Convert class name to snake_case for the evaluator name
18
+ # e.g., WeightedMAE becomes weighted_mae
19
+ name = evaluator_class.name.demodulize.titleize.gsub(/Evaluator/, "").strip
20
+
21
+ EasyML::Core::ModelEvaluator.register(
22
+ name,
23
+ evaluator_class,
24
+ get_supported_tasks(evaluator_class),
25
+ [],
26
+ )
27
+ end
28
+
29
+ def get_supported_tasks(evaluator_class)
30
+ if evaluator_class.respond_to?(:supports_task?)
31
+ [:regression, :classification].select { |task| evaluator_class.supports_task?(task) }
32
+ else
33
+ [:regression, :classification] # Default to supporting both if not specified
34
+ end
35
+ end
36
+ end
37
+ end
38
+ end
39
+
40
+ # Register all evaluators when the initializer loads
41
+ EasyML::Evaluators.register_all
@@ -14,15 +14,16 @@ module EasyML
14
14
  key.split("_").join(" ").titleize
15
15
  end
16
16
 
17
- def to_option
18
- EasyML::Option.new(to_h)
17
+ def description
18
+ "No description provided"
19
19
  end
20
20
 
21
21
  def to_h
22
22
  {
23
23
  value: key,
24
24
  label: label,
25
- direction: direction
25
+ direction: direction,
26
+ description: description,
26
27
  }
27
28
  end
28
29
 
@@ -11,6 +11,10 @@ module EasyML
11
11
  y_pred.eq(y_true).count_true.to_f / y_pred.size
12
12
  end
13
13
 
14
+ def description
15
+ "Overall prediction accuracy"
16
+ end
17
+
14
18
  def direction
15
19
  "maximize"
16
20
  end
@@ -29,6 +33,10 @@ module EasyML
29
33
  true_positives.to_f / predicted_positives
30
34
  end
31
35
 
36
+ def description
37
+ "Ratio of true positives to predicted positives"
38
+ end
39
+
32
40
  def direction
33
41
  "maximize"
34
42
  end
@@ -45,6 +53,10 @@ module EasyML
45
53
  true_positives.to_f / actual_positives
46
54
  end
47
55
 
56
+ def description
57
+ "Ratio of true positives to actual positives"
58
+ end
59
+
48
60
  def direction
49
61
  "maximize"
50
62
  end
@@ -61,6 +73,10 @@ module EasyML
61
73
  2 * (precision * recall) / (precision + recall)
62
74
  end
63
75
 
76
+ def description
77
+ "Harmonic mean of precision and recall"
78
+ end
79
+
64
80
  def direction
65
81
  "maximize"
66
82
  end
@@ -104,6 +120,10 @@ module EasyML
104
120
  auc
105
121
  end
106
122
 
123
+ def description
124
+ "Area under the ROC curve"
125
+ end
126
+
107
127
  def direction
108
128
  "maximize"
109
129
  end
@@ -116,6 +136,10 @@ module EasyML
116
136
  AUC.new.evaluate(y_pred: y_pred, y_true: y_true)
117
137
  end
118
138
 
139
+ def description
140
+ "Area under the ROC curve"
141
+ end
142
+
119
143
  def direction
120
144
  "maximize"
121
145
  end