oracle-ads 2.13.17rc0__py3-none-any.whl → 2.13.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/cli.py +7 -5
- ads/aqua/common/entities.py +88 -29
- ads/aqua/common/enums.py +7 -0
- ads/aqua/common/errors.py +5 -0
- ads/aqua/common/utils.py +87 -7
- ads/aqua/constants.py +3 -0
- ads/aqua/extension/deployment_handler.py +36 -0
- ads/aqua/modeldeployment/config_loader.py +10 -0
- ads/aqua/modeldeployment/constants.py +1 -0
- ads/aqua/modeldeployment/deployment.py +99 -22
- ads/aqua/modeldeployment/entities.py +4 -0
- ads/aqua/resources/gpu_shapes_index.json +315 -26
- ads/aqua/shaperecommend/__init__.py +6 -0
- ads/aqua/shaperecommend/constants.py +116 -0
- ads/aqua/shaperecommend/estimator.py +384 -0
- ads/aqua/shaperecommend/llm_config.py +283 -0
- ads/aqua/shaperecommend/recommend.py +493 -0
- ads/aqua/shaperecommend/shape_report.py +233 -0
- ads/aqua/version.json +1 -1
- ads/cli.py +9 -1
- ads/jobs/builders/infrastructure/dsc_job.py +1 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +9 -1
- ads/model/service/oci_datascience_model_deployment.py +46 -19
- ads/opctl/operator/lowcode/common/data.py +7 -2
- ads/opctl/operator/lowcode/common/transformations.py +207 -0
- ads/opctl/operator/lowcode/common/utils.py +8 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +3 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +53 -3
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/errors.py +5 -0
- ads/opctl/operator/lowcode/forecast/meta_selector.py +310 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/base_model.py +119 -30
- ads/opctl/operator/lowcode/forecast/model/factory.py +33 -2
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +54 -17
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +6 -1
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -0
- ads/pipeline/ads_pipeline.py +13 -9
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/METADATA +1 -1
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/RECORD +43 -36
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/licenses/LICENSE.txt +0 -0
@@ -3,17 +3,20 @@
|
|
3
3
|
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
5
|
|
6
|
+
import copy
|
6
7
|
import json
|
7
8
|
import os
|
8
9
|
import sys
|
9
10
|
from typing import Dict, List
|
10
11
|
|
12
|
+
import pandas as pd
|
11
13
|
import yaml
|
12
14
|
|
13
15
|
from ads.opctl import logger
|
14
16
|
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
|
15
17
|
from ads.opctl.operator.common.utils import _parse_input_args
|
16
18
|
|
19
|
+
from .const import AUTO_SELECT_SERIES
|
17
20
|
from .model.forecast_datasets import ForecastDatasets, ForecastResults
|
18
21
|
from .operator_config import ForecastOperatorConfig
|
19
22
|
from .whatifserve import ModelDeploymentManager
|
@@ -24,9 +27,56 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
|
|
24
27
|
from .model.factory import ForecastOperatorModelFactory
|
25
28
|
|
26
29
|
datasets = ForecastDatasets(operator_config)
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
+
model = ForecastOperatorModelFactory.get_model(operator_config, datasets)
|
31
|
+
|
32
|
+
if operator_config.spec.model == AUTO_SELECT_SERIES and hasattr(
|
33
|
+
operator_config.spec, "meta_features"
|
34
|
+
):
|
35
|
+
# For AUTO_SELECT_SERIES, handle each series with its specific model
|
36
|
+
meta_features = operator_config.spec.meta_features
|
37
|
+
results = ForecastResults()
|
38
|
+
sub_results_list = []
|
39
|
+
|
40
|
+
# Group the data by selected model
|
41
|
+
for model_name in meta_features["selected_model"].unique():
|
42
|
+
# Get series that use this model
|
43
|
+
series_groups = meta_features[meta_features["selected_model"] == model_name]
|
44
|
+
|
45
|
+
# Create a sub-config for this model
|
46
|
+
sub_config = copy.deepcopy(operator_config)
|
47
|
+
sub_config.spec.model = model_name
|
48
|
+
|
49
|
+
# Create sub-datasets for these series
|
50
|
+
sub_datasets = ForecastDatasets(
|
51
|
+
operator_config,
|
52
|
+
subset=series_groups[operator_config.spec.target_category_columns]
|
53
|
+
.values.flatten()
|
54
|
+
.tolist(),
|
55
|
+
)
|
56
|
+
|
57
|
+
# Get and run the appropriate model
|
58
|
+
sub_model = ForecastOperatorModelFactory.get_model(sub_config, sub_datasets)
|
59
|
+
sub_result_df, sub_elapsed_time = sub_model.build_model()
|
60
|
+
sub_results = sub_model.generate_report(
|
61
|
+
result_df=sub_result_df,
|
62
|
+
elapsed_time=sub_elapsed_time,
|
63
|
+
save_sub_reports=True,
|
64
|
+
)
|
65
|
+
sub_results_list.append(sub_results)
|
66
|
+
|
67
|
+
# results_df = pd.concat([results_df, sub_result_df], ignore_index=True, axis=0)
|
68
|
+
# elapsed_time += sub_elapsed_time
|
69
|
+
# Merge all sub_results into a single ForecastResults object
|
70
|
+
if sub_results_list:
|
71
|
+
results = sub_results_list[0]
|
72
|
+
for sub_result in sub_results_list[1:]:
|
73
|
+
results.merge(sub_result)
|
74
|
+
else:
|
75
|
+
results = None
|
76
|
+
|
77
|
+
else:
|
78
|
+
# For other cases, use the single selected model
|
79
|
+
results = model.generate_report()
|
30
80
|
# saving to model catalog
|
31
81
|
spec = operator_config.spec
|
32
82
|
if spec.what_if_analysis and datasets.additional_data:
|
@@ -89,4 +89,6 @@ SUMMARY_METRICS_HORIZON_LIMIT = 10
|
|
89
89
|
PROPHET_INTERNAL_DATE_COL = "ds"
|
90
90
|
RENDER_LIMIT = 5000
|
91
91
|
AUTO_SELECT = "auto-select"
|
92
|
+
AUTO_SELECT_SERIES = "auto-select-series"
|
92
93
|
BACKTEST_REPORT_NAME = "back_test.csv"
|
94
|
+
TROUBLESHOOTING_GUIDE = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-operators/troubleshooting.md"
|
@@ -4,6 +4,9 @@
|
|
4
4
|
# Copyright (c) 2023 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
|
+
from ads.opctl.operator.lowcode.forecast.const import TROUBLESHOOTING_GUIDE
|
8
|
+
|
9
|
+
|
7
10
|
class ForecastSchemaYamlError(Exception):
|
8
11
|
"""Exception raised when there is an issue with the schema."""
|
9
12
|
|
@@ -12,6 +15,7 @@ class ForecastSchemaYamlError(Exception):
|
|
12
15
|
"Invalid forecast operator specification. Check the YAML structure and ensure it "
|
13
16
|
"complies with the required schema for forecast operator. \n"
|
14
17
|
f"{error}"
|
18
|
+
f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps."
|
15
19
|
)
|
16
20
|
|
17
21
|
|
@@ -23,4 +27,5 @@ class ForecastInputDataError(Exception):
|
|
23
27
|
"Invalid input data. Check the input data and ensure it "
|
24
28
|
"complies with the validation criteria. \n"
|
25
29
|
f"{error}"
|
30
|
+
f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps."
|
26
31
|
)
|
@@ -0,0 +1,310 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
|
9
|
+
|
10
|
+
class MetaSelector:
    """
    A class to select the best forecasting model for each series based on pre-learned meta-rules.
    The rules are based on the meta-features calculated by the FFORMS approach.

    Each rule in ``_meta_rules`` is a dict with:

    - ``conditions``: list of ``(feature, operator, threshold)`` tuples; every
      condition must hold for the rule to match.
    - ``model``: name of the model selected when the rule matches.
    - ``priority``: when several rules match a series, the lowest value wins.

    A series that matches no rule falls back to ``"prophet"``.
    """

    def __init__(self):
        """Initialize the MetaSelector with pre-learned meta rules"""
        # Pre-learned rules based on meta-features
        self._meta_rules = {
            # Rule 1: Strong trend, weak seasonality → ARIMA
            "arima_0": {
                "conditions": [
                    ("ts_trend", "abs>=", 0.65),  # Strong trend strength
                    ("ts_seasonal_strength", "<", 0.20),  # Weak seasonality
                ],
                "model": "arima",
                "priority": 1,
            },
            # Rule 2: Strong seasonality, long series → Prophet
            "prophet_0": {
                "conditions": [
                    ("ts_seasonal_strength", ">=", 0.50),  # Strong seasonality
                    ("ts_n_obs", ">=", 200),  # Long series
                ],
                "model": "prophet",
                "priority": 2,
            },
            # Rule 3: High entropy, low autocorrelation → AutoMLX
            # NOTE(review): every other rule uses ts_entropy thresholds in
            # [0.3, 0.7]; confirm that 4.0 here is on the intended scale.
            "automlx_0": {
                "conditions": [
                    ("ts_entropy", ">=", 4.0),  # High entropy
                    ("ts_acf1", "<=", 0.30),  # Low autocorrelation
                ],
                "model": "automlx",
                "priority": 3,
            },
            # Rule 4: Strong seasonality with trend and changing patterns → Prophet
            "prophet_1": {
                "conditions": [
                    ("ts_seasonal_strength", ">=", 0.3),  # Strong seasonality
                    ("ts_trend", "abs>=", 0.1),  # Clear trend
                    ("ts_turning_points_rate", ">=", 0.2),  # Multiple change points
                    ("ts_n_obs", ">=", 50),  # Sufficient data
                    ("ts_step_max", ">=", 100),  # Significant steps
                    ("ts_diff1_variance", ">=", 10),  # Variable differences
                ],
                "model": "prophet",
                "priority": 4,
            },
            # Rule 5: Multiple seasonality with nonlinear patterns → Prophet
            "prophet_2": {
                "conditions": [
                    ("ts_seasonal_peak_strength", ">=", 0.4),  # Strong peak seasonality
                    ("ts_seasonal_strength", ">=", 0.2),  # Overall seasonality
                    ("ts_acf10", ">=", 0.2),  # Long-term correlation
                    ("ts_entropy", ">=", 0.5),  # Complex patterns
                    ("ts_crossing_rate", ">=", 0.3),  # Frequent mean crossings
                ],
                "model": "prophet",
                "priority": 5,
            },
            # Rule 6: Strong autocorrelation with stationary behavior → ARIMA
            "arima_1": {
                "conditions": [
                    ("ts_acf1", ">=", 0.7),  # Strong lag-1 correlation
                    ("ts_acf2", ">=", 0.5),  # Strong lag-2 correlation
                    ("ts_seasonal_strength", "<", 0.3),  # Weak seasonality
                    ("ts_std_residuals", "<", 500),  # Stable residuals
                    ("ts_diff1_variance", "<", 100),  # Stable first differences
                    ("ts_hurst", ">", -0.1),  # Some persistence
                ],
                "model": "arima",
                "priority": 6,
            },
            # Rule 7: Linear trend with moderate noise → ARIMA
            "arima_2": {
                "conditions": [
                    ("ts_trend", "abs>=", 0.15),  # Clear trend
                    ("ts_trend_change", "<", 100),  # Stable trend
                    ("ts_cv", "<", 0.4),  # Low variation
                    ("ts_kurtosis", "<", 5),  # Normal-like distribution
                    ("ts_nonlinearity", "<", 1e5),  # Linear relationships
                ],
                "model": "arima",
                "priority": 7,
            },
            # Rule 8: Complex seasonality with high nonlinearity → NeuralProphet
            "neuralprophet_1": {
                "conditions": [
                    ("ts_seasonal_peak_strength", ">=", 0.5),  # Strong seasonal peaks
                    ("ts_nonlinearity", ">=", 1e6),  # Nonlinear patterns
                    ("ts_n_obs", ">=", 200),  # Long series
                    ("ts_entropy", ">=", 0.6),  # Complex patterns
                    ("ts_diff2_variance", ">=", 50),  # Variable acceleration
                ],
                "model": "neuralprophet",
                "priority": 8,
            },
            # Rule 9: Multiple seasonal patterns with changing behavior → NeuralProphet
            "neuralprophet_2": {
                "conditions": [
                    ("ts_seasonal_strength", ">=", 0.4),  # Strong seasonality
                    ("ts_turning_points_rate", ">=", 0.3),  # Many turning points
                    ("ts_skewness", "abs>=", 1),  # Skewed distribution
                    ("ts_diff1_mean", ">=", 10),  # Large changes
                    ("ts_crossing_rate", ">=", 0.4),  # Frequent crossings
                ],
                "model": "neuralprophet",
                "priority": 9,
            },
            # Rule 10: High volatility with complex patterns → AutoMLX
            "automlx_1": {
                "conditions": [
                    ("ts_cv", ">=", 0.6),  # High variation
                    ("ts_nonlinearity", ">=", 1e7),  # Strong nonlinearity
                    ("ts_spikes_rate", ">=", 0.1),  # Frequent spikes
                    ("ts_entropy", ">=", 0.7),  # Very complex
                    ("ts_std_residuals", ">=", 1000),  # Large residuals
                ],
                "model": "automlx",
                "priority": 10,
            },
            # Rule 11: Unstable patterns with regime changes → AutoMLX
            "automlx_2": {
                "conditions": [
                    ("ts_trend_change", ">=", 200),  # Changing trend
                    ("ts_turning_points_rate", ">=", 0.4),  # Many turning points
                    ("ts_diff2_variance", ">=", 100),  # Variable acceleration
                    ("ts_hurst", "<", -0.2),  # Anti-persistent
                    ("ts_step_max", ">=", 1000),  # Large steps
                ],
                "model": "automlx",
                "priority": 11,
            },
            # Rule 12: Long series with stable seasonality → AutoTS
            "autots_1": {
                "conditions": [
                    ("ts_n_obs", ">=", 150),  # Long series
                    ("ts_seasonal_strength", ">=", 0.2),  # Moderate seasonality
                    ("ts_cv", "<", 0.5),  # Moderate variation
                    ("ts_entropy", "<", 0.5),  # Not too complex
                    ("ts_acf1", ">=", 0.3),  # Some autocorrelation
                ],
                "model": "autots",
                "priority": 12,
            },
            # Rule 13: Stable patterns with low noise → Prophet
            "prophet_3": {
                "conditions": [
                    ("ts_cv", "<", 0.3),  # Low variation
                    ("ts_kurtosis", "<", 4),  # Normal-like
                    ("ts_turning_points_rate", "<", 0.25),  # Few turning points
                    ("ts_diff1_variance", "<", 50),  # Stable changes
                    ("ts_seasonal_strength", ">=", 0.1),  # Some seasonality
                ],
                "model": "prophet",
                "priority": 13,
            },
            # Rule 14: Short series with strong linear patterns → ARIMA
            "arima_3": {
                "conditions": [
                    ("ts_n_obs", "<", 100),  # Short series
                    ("ts_trend", "abs>=", 0.2),  # Strong trend
                    ("ts_entropy", "<", 0.4),  # Simple patterns
                    ("ts_nonlinearity", "<", 1e5),  # Linear
                    ("ts_seasonal_strength", "<", 0.2),  # Weak seasonality
                ],
                "model": "arima",
                "priority": 14,
            },
            # Rule 15: Complex seasonal patterns with long memory → NeuralProphet
            "neuralprophet_3": {
                "conditions": [
                    ("ts_n_obs", ">=", 300),  # Very long series
                    ("ts_seasonal_strength", ">=", 0.3),  # Clear seasonality
                    ("ts_acf10", ">=", 0.3),  # Long memory
                    ("ts_hurst", ">", 0),  # Persistent
                    ("ts_nonlinearity", ">=", 5e5),  # Some nonlinearity
                ],
                "model": "neuralprophet",
                "priority": 15,
            },
            # Rule 16: High complexity with non-normal distribution → AutoMLX
            "automlx_3": {
                "conditions": [
                    ("ts_kurtosis", ">=", 5),  # Heavy tails
                    ("ts_skewness", "abs>=", 2),  # Highly skewed
                    ("ts_entropy", ">=", 0.6),  # Complex
                    ("ts_spikes_rate", ">=", 0.05),  # Some spikes
                    ("ts_diff2_mean", ">=", 5),  # Changing acceleration
                ],
                "model": "automlx",
                "priority": 16,
            },
            # Rule 17: Simple patterns with weak seasonality → AutoTS
            "autots_2": {
                "conditions": [
                    ("ts_entropy", "<", 0.3),  # Simple patterns
                    ("ts_seasonal_strength", "<", 0.3),  # Weak seasonality
                    ("ts_cv", "<", 0.4),  # Low variation
                    ("ts_nonlinearity", "<", 1e5),  # Nearly linear
                    ("ts_diff1_mean", "<", 10),  # Small changes
                ],
                "model": "autots",
                "priority": 17,
            },
        }

    def _evaluate_condition(self, value, operator, threshold):
        """
        Evaluate a single ``value <operator> threshold`` condition.

        Missing values (anything ``pd.isna`` treats as missing) never satisfy a
        condition. Supported operators: ``>=``, ``>``, ``<``, ``<=``, ``abs>=``
        and ``abs<``. An unrecognized operator evaluates to ``False``.
        """
        if pd.isna(value):
            return False

        if operator == ">=":
            return value >= threshold
        elif operator == ">":
            return value > threshold
        elif operator == "<":
            return value < threshold
        elif operator == "<=":
            return value <= threshold
        elif operator == "abs>=":
            return abs(value) >= threshold
        elif operator == "abs<":
            return abs(value) < threshold
        return False

    def _check_model_conditions(self, features, model_rules):
        """
        Return True if ``features`` satisfies every condition of ``model_rules``.

        ``features`` is looked up with ``in`` / ``[]``; for a pandas Series
        (a row from ``iterrows``) membership is tested against the index. A
        feature absent from ``features`` makes the rule fail.
        """
        for feature, operator, threshold in model_rules["conditions"]:
            if feature not in features:
                return False
            if not self._evaluate_condition(features[feature], operator, threshold):
                return False
        return True

    def select_best_model(self, meta_features_df):
        """
        Select the best model for each series based on pre-learned rules.

        Parameters
        ----------
        meta_features_df : pandas.DataFrame
            DataFrame containing meta-features for each series, as returned by
            build_fforms_meta_features. Columns whose name starts with ``ts_``
            are treated as meta-features; every other column is treated as a
            series identifier and copied to the output unchanged.

        Returns
        -------
        pandas.DataFrame
            One row per input row, with the identifier columns followed by
            ``matched_features`` (list of ``(feature, value)`` pairs for the
            winning rule, empty for the default), ``selected_model`` and
            ``rule_matched`` (rule name, or ``"default"`` when nothing matched).
            For empty input an empty frame with these columns is returned.
        """
        # Identifier columns are the same for every row, so compute them once.
        # ``str(col)`` keeps non-string column labels from raising.
        group_cols = [
            col for col in meta_features_df.columns if not str(col).startswith("ts_")
        ]

        # Evaluate rules by ascending priority so the first match is the winner.
        # ``sorted`` is stable, so equal priorities keep their dict order, which
        # is the same tie-break that ``min`` over all matches would give.
        ordered_rules = sorted(
            self._meta_rules.items(), key=lambda item: item[1]["priority"]
        )

        results = []
        for _, row in meta_features_df.iterrows():
            series_info = {col: row[col] for col in group_cols}

            best_rule, best_spec = "default", None
            for rule_name, rule_spec in ordered_rules:
                if self._check_model_conditions(row, rule_spec):
                    best_rule, best_spec = rule_name, rule_spec
                    break

            if best_spec is not None:
                best_model = best_spec["model"]
                # Record the values of the features that triggered the winning rule.
                series_info["matched_features"] = [
                    (feature, row[feature]) for feature, _, _ in best_spec["conditions"]
                ]
            else:
                best_model = "prophet"  # Default to prophet if no rules match
                series_info["matched_features"] = []

            series_info["selected_model"] = best_model
            series_info["rule_matched"] = best_rule
            results.append(series_info)

        if not results:
            # Keep the output schema stable for empty input so callers can
            # still index e.g. ``selected_model`` without a KeyError.
            return pd.DataFrame(
                columns=[
                    *group_cols,
                    "matched_features",
                    "selected_model",
                    "rule_matched",
                ]
            )
        return pd.DataFrame(results)

    def get_model_conditions(self):
        """
        Get the pre-learned conditions for each model.
        This is read-only and cannot be modified at runtime.

        Returns
        -------
        dict
            Deep copy of the rules dictionary; mutating it (including the nested
            rule dicts and condition lists) does not affect this selector.
        """
        import copy

        # A shallow ``dict.copy()`` would share the nested rule dicts and
        # condition lists with ``self._meta_rules``, breaking the read-only promise.
        return copy.deepcopy(self._meta_rules)
|
@@ -158,7 +158,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
158
158
|
summary_frame = model.forecast(
|
159
159
|
X=X_pred,
|
160
160
|
periods=horizon,
|
161
|
-
alpha=1 -
|
161
|
+
alpha=1 - self.spec.confidence_interval_width,
|
162
162
|
)
|
163
163
|
|
164
164
|
fitted_values = model.predict(data_i.drop(target, axis=1))[
|