oracle-ads 2.12.8__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +4 -4
  2. ads/aqua/app.py +12 -2
  3. ads/aqua/common/enums.py +3 -0
  4. ads/aqua/common/utils.py +62 -2
  5. ads/aqua/data.py +2 -19
  6. ads/aqua/evaluation/entities.py +6 -0
  7. ads/aqua/evaluation/evaluation.py +25 -3
  8. ads/aqua/extension/deployment_handler.py +8 -4
  9. ads/aqua/extension/finetune_handler.py +8 -14
  10. ads/aqua/extension/model_handler.py +25 -6
  11. ads/aqua/extension/ui_handler.py +13 -1
  12. ads/aqua/finetuning/constants.py +5 -2
  13. ads/aqua/finetuning/entities.py +70 -17
  14. ads/aqua/finetuning/finetuning.py +79 -82
  15. ads/aqua/model/entities.py +4 -1
  16. ads/aqua/model/model.py +95 -29
  17. ads/aqua/modeldeployment/deployment.py +13 -1
  18. ads/aqua/modeldeployment/entities.py +7 -4
  19. ads/aqua/ui.py +24 -2
  20. ads/common/auth.py +9 -9
  21. ads/llm/autogen/__init__.py +2 -0
  22. ads/llm/autogen/constants.py +15 -0
  23. ads/llm/autogen/reports/__init__.py +2 -0
  24. ads/llm/autogen/reports/base.py +67 -0
  25. ads/llm/autogen/reports/data.py +103 -0
  26. ads/llm/autogen/reports/session.py +526 -0
  27. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  28. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  29. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  30. ads/llm/autogen/reports/utils.py +56 -0
  31. ads/llm/autogen/v02/__init__.py +4 -0
  32. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  33. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  34. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  35. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  36. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  37. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  38. ads/llm/autogen/v02/loggers/utils.py +86 -0
  39. ads/llm/autogen/v02/runtime_logging.py +163 -0
  40. ads/llm/guardrails/base.py +6 -5
  41. ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
  42. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
  43. ads/model/__init__.py +11 -13
  44. ads/model/artifact.py +47 -8
  45. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  46. ads/model/framework/embedding_onnx_model.py +438 -0
  47. ads/model/generic_model.py +26 -24
  48. ads/model/model_metadata.py +8 -7
  49. ads/opctl/config/merger.py +13 -14
  50. ads/opctl/operator/common/operator_config.py +4 -4
  51. ads/opctl/operator/lowcode/common/transformations.py +12 -5
  52. ads/opctl/operator/lowcode/common/utils.py +11 -5
  53. ads/opctl/operator/lowcode/forecast/const.py +3 -0
  54. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  55. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  56. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  57. ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
  58. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  59. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  60. ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
  61. ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
  62. ads/opctl/operator/lowcode/forecast/utils.py +8 -6
  63. ads/telemetry/base.py +18 -11
  64. ads/telemetry/client.py +33 -13
  65. ads/templates/schemas/openapi.json +1740 -0
  66. ads/templates/score_embedding_onnx.jinja2 +202 -0
  67. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +9 -10
  68. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +71 -50
  69. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -28,6 +28,7 @@ from ads.opctl.operator.lowcode.common.utils import (
28
28
  seconds_to_datetime,
29
29
  write_data,
30
30
  )
31
+ from ads.opctl.operator.lowcode.common.const import DataColumns
31
32
  from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
32
33
  from ads.opctl.operator.lowcode.forecast.utils import (
33
34
  _build_metrics_df,
@@ -43,10 +44,12 @@ from ads.opctl.operator.lowcode.forecast.utils import (
43
44
 
44
45
  from ..const import (
45
46
  AUTO_SELECT,
47
+ BACKTEST_REPORT_NAME,
46
48
  SUMMARY_METRICS_HORIZON_LIMIT,
47
49
  SpeedAccuracyMode,
48
50
  SupportedMetrics,
49
51
  SupportedModels,
52
+ BACKTEST_REPORT_NAME,
50
53
  )
51
54
  from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
52
55
  from .forecast_datasets import ForecastDatasets
@@ -68,7 +71,7 @@ class ForecastOperatorBaseModel(ABC):
68
71
  self.config: ForecastOperatorConfig = config
69
72
  self.spec: ForecastOperatorSpec = config.spec
70
73
  self.datasets: ForecastDatasets = datasets
71
-
74
+ self.target_cat_col = self.spec.target_category_columns
72
75
  self.full_data_dict = datasets.get_data_by_series()
73
76
 
74
77
  self.test_eval_metrics = None
@@ -123,6 +126,9 @@ class ForecastOperatorBaseModel(ABC):
123
126
 
124
127
  if self.spec.generate_report or self.spec.generate_metrics:
125
128
  self.eval_metrics = self.generate_train_metrics()
129
+ if not self.target_cat_col:
130
+ self.eval_metrics.rename({"Series 1": self.original_target_column},
131
+ axis=1, inplace=True)
126
132
 
127
133
  if self.spec.test_data:
128
134
  try:
@@ -133,6 +139,9 @@ class ForecastOperatorBaseModel(ABC):
133
139
  ) = self._test_evaluate_metrics(
134
140
  elapsed_time=elapsed_time,
135
141
  )
142
+ if not self.target_cat_col:
143
+ self.test_eval_metrics.rename({"Series 1": self.original_target_column},
144
+ axis=1, inplace=True)
136
145
  except Exception:
137
146
  logger.warn("Unable to generate Test Metrics.")
138
147
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -178,7 +187,7 @@ class ForecastOperatorBaseModel(ABC):
178
187
  first_5_rows_blocks = [
179
188
  rc.DataTable(
180
189
  df.head(5),
181
- label=s_id,
190
+ label=s_id if self.target_cat_col else None,
182
191
  index=True,
183
192
  )
184
193
  for s_id, df in self.full_data_dict.items()
@@ -187,7 +196,7 @@ class ForecastOperatorBaseModel(ABC):
187
196
  last_5_rows_blocks = [
188
197
  rc.DataTable(
189
198
  df.tail(5),
190
- label=s_id,
199
+ label=s_id if self.target_cat_col else None,
191
200
  index=True,
192
201
  )
193
202
  for s_id, df in self.full_data_dict.items()
@@ -196,7 +205,7 @@ class ForecastOperatorBaseModel(ABC):
196
205
  data_summary_blocks = [
197
206
  rc.DataTable(
198
207
  df.describe(),
199
- label=s_id,
208
+ label=s_id if self.target_cat_col else None,
200
209
  index=True,
201
210
  )
202
211
  for s_id, df in self.full_data_dict.items()
@@ -214,17 +223,17 @@ class ForecastOperatorBaseModel(ABC):
214
223
  rc.Block(
215
224
  first_10_title,
216
225
  # series_subtext,
217
- rc.Select(blocks=first_5_rows_blocks),
226
+ rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
218
227
  ),
219
228
  rc.Block(
220
229
  last_10_title,
221
230
  # series_subtext,
222
- rc.Select(blocks=last_5_rows_blocks),
231
+ rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
223
232
  ),
224
233
  rc.Block(
225
234
  summary_title,
226
235
  # series_subtext,
227
- rc.Select(blocks=data_summary_blocks),
236
+ rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
228
237
  ),
229
238
  rc.Separator(),
230
239
  )
@@ -256,11 +265,12 @@ class ForecastOperatorBaseModel(ABC):
256
265
 
257
266
  backtest_sections = []
258
267
  output_dir = self.spec.output_directory.url
259
- backtest_report_name = "backtest_stats.csv"
260
- file_path = f"{output_dir}/{backtest_report_name}"
268
+ file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
261
269
  if self.spec.model == AUTO_SELECT:
262
270
  backtest_sections.append(
263
- rc.Heading("Auto-select statistics", level=2)
271
+ rc.Heading(
272
+ "Auto-Select Backtesting and Performance Metrics", level=2
273
+ )
264
274
  )
265
275
  if not os.path.exists(file_path):
266
276
  failure_msg = rc.Text(
@@ -270,18 +280,22 @@ class ForecastOperatorBaseModel(ABC):
270
280
  backtest_sections.append(failure_msg)
271
281
  else:
272
282
  backtest_stats = pd.read_csv(file_path)
273
- average_dict = backtest_stats.mean().to_dict()
274
- del average_dict["backtest"]
283
+ model_metric_map = backtest_stats.drop(
284
+ columns=["metric", "backtest"]
285
+ )
286
+ average_dict = {
287
+ k: round(v, 4)
288
+ for k, v in model_metric_map.mean().to_dict().items()
289
+ }
275
290
  best_model = min(average_dict, key=average_dict.get)
276
- backtest_text = rc.Heading("Back Testing Metrics", level=3)
277
291
  summary_text = rc.Text(
278
- f"Overall, the average scores for the models are {average_dict}, with {best_model}"
279
- f" being identified as the top-performing model during backtesting."
292
+ f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
293
+ f" {best_model} being identified as the top-performing model during backtesting."
280
294
  )
281
295
  backtest_table = rc.DataTable(backtest_stats, index=True)
282
296
  liner_plot = get_auto_select_plot(backtest_stats)
283
297
  backtest_sections.extend(
284
- [backtest_text, backtest_table, summary_text, liner_plot]
298
+ [backtest_table, summary_text, liner_plot]
285
299
  )
286
300
 
287
301
  forecast_plots = []
@@ -294,6 +308,7 @@ class ForecastOperatorBaseModel(ABC):
294
308
  horizon=self.spec.horizon,
295
309
  test_data=test_data,
296
310
  ci_interval_width=self.spec.confidence_interval_width,
311
+ target_category_column=self.target_cat_col
297
312
  )
298
313
  if (
299
314
  series_name is not None
@@ -307,7 +322,14 @@ class ForecastOperatorBaseModel(ABC):
307
322
  forecast_plots = [forecast_text, forecast_sec]
308
323
 
309
324
  yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
310
- yaml_appendix = rc.Yaml(self.config.to_dict())
325
+ config_dict = self.config.to_dict()
326
+ # pop the data in case it isn't json serializable
327
+ config_dict["spec"]["historical_data"].pop("data")
328
+ if config_dict["spec"].get("additional_data"):
329
+ config_dict["spec"]["additional_data"].pop("data")
330
+ if config_dict["spec"].get("test_data"):
331
+ config_dict["spec"]["test_data"].pop("data")
332
+ yaml_appendix = rc.Yaml(config_dict)
311
333
  report_sections = (
312
334
  [summary]
313
335
  + backtest_sections
@@ -469,6 +491,7 @@ class ForecastOperatorBaseModel(ABC):
469
491
  f2.write(f1.read())
470
492
 
471
493
  # forecast csv report
494
+ result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
472
495
  write_data(
473
496
  data=result_df,
474
497
  filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
@@ -643,6 +666,13 @@ class ForecastOperatorBaseModel(ABC):
643
666
  storage_options=storage_options,
644
667
  )
645
668
 
669
+ def _validate_automlx_explanation_mode(self):
670
+ if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
671
+ raise ValueError(
672
+ "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
673
+ "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
674
+ )
675
+
646
676
  @runtime_dependency(
647
677
  module="shap",
648
678
  err_msg=(
@@ -671,6 +701,9 @@ class ForecastOperatorBaseModel(ABC):
671
701
  )
672
702
  ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
673
703
 
704
+ # validate the automlx mode is use for automlx model
705
+ self._validate_automlx_explanation_mode()
706
+
674
707
  for s_id, data_i in self.datasets.get_data_by_series(
675
708
  include_horizon=False
676
709
  ).items():
@@ -705,6 +738,14 @@ class ForecastOperatorBaseModel(ABC):
705
738
  logger.warn(
706
739
  "No explanations generated. Ensure that additional data has been provided."
707
740
  )
741
+ elif (
742
+ self.spec.model == SupportedModels.AutoMLX
743
+ and self.spec.explanations_accuracy_mode
744
+ == SpeedAccuracyMode.AUTOMLX
745
+ ):
746
+ logger.warning(
747
+ "Global explanations not available for AutoMLX models with inherent explainability"
748
+ )
708
749
  else:
709
750
  self.global_explanation[s_id] = dict(
710
751
  zip(
@@ -360,7 +360,7 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
360
360
  pd.Series(
361
361
  m.state_dict(),
362
362
  index=m.state_dict().keys(),
363
- name=s_id,
363
+ name=s_id if self.target_cat_col else self.original_target_column,
364
364
  )
365
365
  )
366
366
  all_model_states = pd.concat(model_states, axis=1)
@@ -373,6 +373,13 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
373
373
  # If the key is present, call the "explain_model" method
374
374
  self.explain_model()
375
375
 
376
+ if not self.target_cat_col:
377
+ self.formatted_global_explanation = self.formatted_global_explanation.rename(
378
+ {"Series 1": self.original_target_column},
379
+ axis=1,
380
+ )
381
+ self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
382
+
376
383
  # Create a markdown section for the global explainability
377
384
  global_explanation_section = rc.Block(
378
385
  rc.Heading("Global Explainability", level=2),
@@ -385,14 +392,14 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
385
392
  blocks = [
386
393
  rc.DataTable(
387
394
  local_ex_df.drop("Series", axis=1),
388
- label=s_id,
395
+ label=s_id if self.target_cat_col else None,
389
396
  index=True,
390
397
  )
391
398
  for s_id, local_ex_df in self.local_explanation.items()
392
399
  ]
393
400
  local_explanation_section = rc.Block(
394
401
  rc.Heading("Local Explanation of Models", level=2),
395
- rc.Select(blocks=blocks),
402
+ rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
396
403
  )
397
404
 
398
405
  # Append the global explanation text and section to the "all_sections" list
@@ -256,6 +256,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
256
256
  self.outputs[s_id], include_legend=True
257
257
  ),
258
258
  series_ids=series_ids,
259
+ target_category_column=self.target_cat_col
259
260
  )
260
261
  section_1 = rc.Block(
261
262
  rc.Heading("Forecast Overview", level=2),
@@ -268,6 +269,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
268
269
  sec2 = _select_plot_list(
269
270
  lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
270
271
  series_ids=series_ids,
272
+ target_category_column=self.target_cat_col
271
273
  )
272
274
  section_2 = rc.Block(
273
275
  rc.Heading("Forecast Broken Down by Trend Component", level=2), sec2
@@ -281,7 +283,9 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
281
283
  sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
282
284
  )
283
285
  sec3 = _select_plot_list(
284
- lambda s_id: sec3_figs[s_id], series_ids=series_ids
286
+ lambda s_id: sec3_figs[s_id],
287
+ series_ids=series_ids,
288
+ target_category_column=self.target_cat_col
285
289
  )
286
290
  section_3 = rc.Block(rc.Heading("Forecast Changepoints", level=2), sec3)
287
291
 
@@ -295,7 +299,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
295
299
  pd.Series(
296
300
  m.seasonalities,
297
301
  index=pd.Index(m.seasonalities.keys(), dtype="object"),
298
- name=s_id,
302
+ name=s_id if self.target_cat_col else self.original_target_column,
299
303
  dtype="object",
300
304
  )
301
305
  )
@@ -316,15 +320,6 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
316
320
  global_explanation_df / global_explanation_df.sum(axis=0) * 100
317
321
  )
318
322
 
319
- # Create a markdown section for the global explainability
320
- global_explanation_section = rc.Block(
321
- rc.Heading("Global Explanation of Models", level=2),
322
- rc.Text(
323
- "The following tables provide the feature attribution for the global explainability."
324
- ),
325
- rc.DataTable(self.formatted_global_explanation, index=True),
326
- )
327
-
328
323
  aggregate_local_explanations = pd.DataFrame()
329
324
  for s_id, local_ex_df in self.local_explanation.items():
330
325
  local_ex_df_copy = local_ex_df.copy()
@@ -334,17 +329,33 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
334
329
  )
335
330
  self.formatted_local_explanation = aggregate_local_explanations
336
331
 
332
+ if not self.target_cat_col:
333
+ self.formatted_global_explanation = self.formatted_global_explanation.rename(
334
+ {"Series 1": self.original_target_column},
335
+ axis=1,
336
+ )
337
+ self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
338
+
339
+ # Create a markdown section for the global explainability
340
+ global_explanation_section = rc.Block(
341
+ rc.Heading("Global Explanation of Models", level=2),
342
+ rc.Text(
343
+ "The following tables provide the feature attribution for the global explainability."
344
+ ),
345
+ rc.DataTable(self.formatted_global_explanation, index=True),
346
+ )
347
+
337
348
  blocks = [
338
349
  rc.DataTable(
339
350
  local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
340
- label=s_id,
351
+ label=s_id if self.target_cat_col else None,
341
352
  index=True,
342
353
  )
343
354
  for s_id, local_ex_df in self.local_explanation.items()
344
355
  ]
345
356
  local_explanation_section = rc.Block(
346
357
  rc.Heading("Local Explanation of Models", level=2),
347
- rc.Select(blocks=blocks),
358
+ rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
348
359
  )
349
360
 
350
361
  # Append the global explanation text and section to the "all_sections" list
@@ -358,11 +369,7 @@ class ProphetOperatorModel(ForecastOperatorBaseModel):
358
369
  logger.debug(f"Full Traceback: {traceback.format_exc()}")
359
370
 
360
371
  model_description = rc.Text(
361
- "Prophet is a procedure for forecasting time series data based on an additive "
362
- "model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
363
- "plus holiday effects. It works best with time series that have strong seasonal "
364
- "effects and several seasons of historical data. Prophet is robust to missing "
365
- "data and shifts in the trend, and typically handles outliers well."
372
+ """Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
366
373
  )
367
374
  other_sections = all_sections
368
375
 
@@ -10,6 +10,7 @@ from pathlib import Path
10
10
 
11
11
  from ads.opctl import logger
12
12
  from ads.opctl.operator.lowcode.common.const import DataColumns
13
+ from ads.opctl.operator.lowcode.forecast.const import BACKTEST_REPORT_NAME
13
14
  from .model.forecast_datasets import ForecastDatasets
14
15
  from .operator_config import ForecastOperatorConfig
15
16
  from ads.opctl.operator.lowcode.forecast.model.factory import SupportedModels
@@ -156,8 +157,8 @@ class ModelEvaluator:
156
157
  best_model = min(avg_backtests_metric, key=avg_backtests_metric.get)
157
158
  logger.info(f"Among models {self.models}, {best_model} model shows better performance during backtesting.")
158
159
  backtest_stats = pd.DataFrame(nonempty_metrics).rename_axis('backtest')
160
+ backtest_stats["metric"] = operator_config.spec.metric
159
161
  backtest_stats.reset_index(inplace=True)
160
162
  output_dir = operator_config.spec.output_directory.url
161
- backtest_report_name = "backtest_stats.csv"
162
- backtest_stats.to_csv(f"{output_dir}/{backtest_report_name}", index=False)
163
+ backtest_stats.to_csv(f"{output_dir}/{BACKTEST_REPORT_NAME}", index=False)
163
164
  return best_model
@@ -37,6 +37,9 @@ spec:
37
37
  nullable: true
38
38
  required: false
39
39
  type: dict
40
+ data:
41
+ nullable: true
42
+ required: false
40
43
  format:
41
44
  allowed:
42
45
  - csv
@@ -48,6 +51,7 @@ spec:
48
51
  - sql_query
49
52
  - hdf
50
53
  - tsv
54
+ - pandas
51
55
  required: false
52
56
  type: string
53
57
  columns:
@@ -92,6 +96,9 @@ spec:
92
96
  nullable: true
93
97
  required: false
94
98
  type: dict
99
+ data:
100
+ nullable: true
101
+ required: false
95
102
  format:
96
103
  allowed:
97
104
  - csv
@@ -103,6 +110,7 @@ spec:
103
110
  - sql_query
104
111
  - hdf
105
112
  - tsv
113
+ - pandas
106
114
  required: false
107
115
  type: string
108
116
  columns:
@@ -146,6 +154,9 @@ spec:
146
154
  nullable: true
147
155
  required: false
148
156
  type: dict
157
+ data:
158
+ nullable: true
159
+ required: false
149
160
  format:
150
161
  allowed:
151
162
  - csv
@@ -157,6 +168,7 @@ spec:
157
168
  - sql_query
158
169
  - hdf
159
170
  - tsv
171
+ - pandas
160
172
  required: false
161
173
  type: string
162
174
  columns:
@@ -332,6 +344,7 @@ spec:
332
344
  - HIGH_ACCURACY
333
345
  - BALANCED
334
346
  - FAST_APPROXIMATE
347
+ - AUTOMLX
335
348
 
336
349
  generate_report:
337
350
  type: boolean
@@ -250,8 +250,8 @@ def evaluate_train_metrics(output):
250
250
  return total_metrics
251
251
 
252
252
 
253
- def _select_plot_list(fn, series_ids):
254
- blocks = [rc.Widget(fn(s_id=s_id), label=s_id) for s_id in series_ids]
253
+ def _select_plot_list(fn, series_ids, target_category_column):
254
+ blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
255
255
  return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
256
256
 
257
257
 
@@ -261,10 +261,11 @@ def _add_unit(num, unit):
261
261
 
262
262
  def get_auto_select_plot(backtest_results):
263
263
  fig = go.Figure()
264
- columns = backtest_results.columns.tolist()
264
+ back_test_csv_columns = backtest_results.columns.tolist()
265
265
  back_test_column = "backtest"
266
- columns.remove(back_test_column)
267
- for column in columns:
266
+ metric_column = "metric"
267
+ models = [x for x in back_test_csv_columns if x not in [back_test_column, metric_column]]
268
+ for i, column in enumerate(models):
268
269
  fig.add_trace(
269
270
  go.Scatter(
270
271
  x=backtest_results[back_test_column],
@@ -282,6 +283,7 @@ def get_forecast_plots(
282
283
  horizon,
283
284
  test_data=None,
284
285
  ci_interval_width=0.95,
286
+ target_category_column=None
285
287
  ):
286
288
  def plot_forecast_plotly(s_id):
287
289
  fig = go.Figure()
@@ -378,7 +380,7 @@ def get_forecast_plots(
378
380
  )
379
381
  return fig
380
382
 
381
- return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids())
383
+ return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
382
384
 
383
385
 
384
386
  def convert_target(target: str, target_col: str):
ads/telemetry/base.py CHANGED
@@ -1,17 +1,18 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
  import logging
7
6
 
8
- from ads import set_auth
7
+ import oci
8
+
9
9
  from ads.common import oci_client as oc
10
- from ads.common.auth import default_signer
10
+ from ads.common.auth import default_signer, resource_principal
11
11
  from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
12
12
 
13
-
14
13
  logger = logging.getLogger(__name__)
14
+
15
+
15
16
  class TelemetryBase:
16
17
  """Base class for Telemetry Client."""
17
18
 
@@ -25,15 +26,21 @@ class TelemetryBase:
25
26
  namespace : str, optional
26
27
  Namespace of the OCI object storage bucket, by default None.
27
28
  """
29
+ # Use resource principal as authentication method if available,
30
+ # however, do not change the ADS authentication if user configured it by set_auth.
28
31
  if OCI_RESOURCE_PRINCIPAL_VERSION:
29
- set_auth("resource_principal")
30
- self._auth = default_signer()
31
- self.os_client = oc.OCIClientFactory(**self._auth).object_storage
32
+ self._auth = resource_principal()
33
+ else:
34
+ self._auth = default_signer()
35
+ self.os_client: oci.object_storage.ObjectStorageClient = oc.OCIClientFactory(
36
+ **self._auth
37
+ ).object_storage
32
38
  self.bucket = bucket
33
39
  self._namespace = namespace
34
40
  self._service_endpoint = None
35
- logger.debug(f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}")
36
-
41
+ logger.debug(
42
+ f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}"
43
+ )
37
44
 
38
45
  @property
39
46
  def namespace(self) -> str:
@@ -58,5 +65,5 @@ class TelemetryBase:
58
65
  Tenancy-specific endpoint.
59
66
  """
60
67
  if not self._service_endpoint:
61
- self._service_endpoint = self.os_client.base_client.endpoint
68
+ self._service_endpoint = str(self.os_client.base_client.endpoint)
62
69
  return self._service_endpoint
ads/telemetry/client.py CHANGED
@@ -1,17 +1,19 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
 
6
5
 
7
6
  import logging
8
7
  import threading
8
+ import traceback
9
9
  import urllib.parse
10
- import requests
11
- from requests import Response
12
- from .base import TelemetryBase
10
+ from typing import Optional
11
+
12
+ import oci
13
+
13
14
  from ads.config import DEBUG_TELEMETRY
14
15
 
16
+ from .base import TelemetryBase
15
17
 
16
18
  logger = logging.getLogger(__name__)
17
19
 
@@ -32,7 +34,7 @@ class TelemetryClient(TelemetryBase):
32
34
  >>> import traceback
33
35
  >>> from ads.telemetry.client import TelemetryClient
34
36
  >>> AQUA_BUCKET = os.environ.get("AQUA_BUCKET", "service-managed-models")
35
- >>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "ociodscdev")
37
+ >>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "namespace")
36
38
  >>> telemetry = TelemetryClient(bucket=AQUA_BUCKET, namespace=AQUA_BUCKET_NS)
37
39
  >>> telemetry.record_event_async(category="aqua/service/model", action="create") # records create action
38
40
  >>> telemetry.record_event_async(category="aqua/service/model/create", action="shape", detail="VM.GPU.A10.1")
@@ -45,7 +47,7 @@ class TelemetryClient(TelemetryBase):
45
47
 
46
48
  def record_event(
47
49
  self, category: str = None, action: str = None, detail: str = None, **kwargs
48
- ) -> Response:
50
+ ) -> Optional[int]:
49
51
  """Send a head request to generate an event record.
50
52
 
51
53
  Parameters
@@ -62,23 +64,41 @@ class TelemetryClient(TelemetryBase):
62
64
 
63
65
  Returns
64
66
  -------
65
- Response
67
+ int
68
+ The status code for the telemetry request.
69
+ 200: The object exists for the telemetry request
70
+ 404: The object does not exist for the telemetry request.
71
+ Note that for telemetry purpose, the object does not need to be exist.
72
+ `None` will be returned if the telemetry request failed.
66
73
  """
67
74
  try:
68
75
  if not category or not action:
69
76
  raise ValueError("Please specify the category and the action.")
70
77
  if detail:
71
78
  category, action = f"{category}/{action}", detail
79
+ # Here `endpoint`` is for debugging purpose
80
+ # For some federated/domain users, the `endpoint` may not be a valid URL
72
81
  endpoint = f"{self.service_endpoint}/n/{self.namespace}/b/{self.bucket}/o/telemetry/{category}/{action}"
73
- headers = {"User-Agent": self._encode_user_agent(**kwargs)}
74
82
  logger.debug(f"Sending telemetry to endpoint: {endpoint}")
75
- signer = self._auth["signer"]
76
- response = requests.head(endpoint, auth=signer, headers=headers)
77
- logger.debug(f"Telemetry status code: {response.status_code}")
78
- return response
83
+
84
+ self.os_client.base_client.user_agent = self._encode_user_agent(**kwargs)
85
+ try:
86
+ response: oci.response.Response = self.os_client.head_object(
87
+ namespace_name=self.namespace,
88
+ bucket_name=self.bucket,
89
+ object_name=f"telemetry/{category}/{action}",
90
+ )
91
+ logger.debug(f"Telemetry status: {response.status}")
92
+ return response.status
93
+ except oci.exceptions.ServiceError as ex:
94
+ if ex.status == 404:
95
+ return ex.status
96
+ raise ex
79
97
  except Exception as e:
80
98
  if DEBUG_TELEMETRY:
81
99
  logger.error(f"There is an error recording telemetry: {e}")
100
+ traceback.print_exc()
101
+ return None
82
102
 
83
103
  def record_event_async(
84
104
  self, category: str = None, action: str = None, detail: str = None, **kwargs