teradataml 20.0.0.0__py3-none-any.whl → 20.0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- teradataml/LICENSE-3RD-PARTY.pdf +0 -0
- teradataml/LICENSE.pdf +0 -0
- teradataml/README.md +71 -0
- teradataml/_version.py +2 -2
- teradataml/analytics/analytic_function_executor.py +51 -24
- teradataml/analytics/json_parser/utils.py +11 -17
- teradataml/automl/__init__.py +103 -48
- teradataml/automl/data_preparation.py +55 -37
- teradataml/automl/data_transformation.py +131 -69
- teradataml/automl/feature_engineering.py +117 -185
- teradataml/automl/feature_exploration.py +9 -2
- teradataml/automl/model_evaluation.py +13 -25
- teradataml/automl/model_training.py +214 -75
- teradataml/catalog/model_cataloging_utils.py +1 -1
- teradataml/clients/auth_client.py +133 -0
- teradataml/common/aed_utils.py +3 -2
- teradataml/common/constants.py +11 -6
- teradataml/common/garbagecollector.py +5 -0
- teradataml/common/messagecodes.py +3 -1
- teradataml/common/messages.py +2 -1
- teradataml/common/utils.py +6 -0
- teradataml/context/context.py +49 -29
- teradataml/data/advertising.csv +201 -0
- teradataml/data/bank_marketing.csv +11163 -0
- teradataml/data/bike_sharing.csv +732 -0
- teradataml/data/boston2cols.csv +721 -0
- teradataml/data/breast_cancer.csv +570 -0
- teradataml/data/customer_segmentation_test.csv +2628 -0
- teradataml/data/customer_segmentation_train.csv +8069 -0
- teradataml/data/docs/sqle/docs_17_10/OneHotEncodingFit.py +3 -1
- teradataml/data/docs/sqle/docs_17_10/OneHotEncodingTransform.py +6 -0
- teradataml/data/docs/sqle/docs_17_10/OutlierFilterTransform.py +5 -1
- teradataml/data/docs/sqle/docs_17_20/ANOVA.py +61 -1
- teradataml/data/docs/sqle/docs_17_20/ColumnTransformer.py +2 -0
- teradataml/data/docs/sqle/docs_17_20/FTest.py +105 -26
- teradataml/data/docs/sqle/docs_17_20/GLM.py +162 -1
- teradataml/data/docs/sqle/docs_17_20/GetFutileColumns.py +5 -3
- teradataml/data/docs/sqle/docs_17_20/KMeans.py +48 -1
- teradataml/data/docs/sqle/docs_17_20/NonLinearCombineFit.py +3 -2
- teradataml/data/docs/sqle/docs_17_20/OneHotEncodingFit.py +5 -0
- teradataml/data/docs/sqle/docs_17_20/OneHotEncodingTransform.py +6 -0
- teradataml/data/docs/sqle/docs_17_20/ROC.py +3 -2
- teradataml/data/docs/sqle/docs_17_20/SVMPredict.py +13 -2
- teradataml/data/docs/sqle/docs_17_20/ScaleFit.py +119 -1
- teradataml/data/docs/sqle/docs_17_20/ScaleTransform.py +93 -1
- teradataml/data/docs/sqle/docs_17_20/TDGLMPredict.py +163 -1
- teradataml/data/docs/sqle/docs_17_20/XGBoost.py +12 -4
- teradataml/data/docs/sqle/docs_17_20/XGBoostPredict.py +7 -1
- teradataml/data/docs/sqle/docs_17_20/ZTest.py +72 -7
- teradataml/data/glm_example.json +28 -1
- teradataml/data/housing_train_segment.csv +201 -0
- teradataml/data/insect2Cols.csv +61 -0
- teradataml/data/jsons/sqle/17.20/TD_ANOVA.json +99 -27
- teradataml/data/jsons/sqle/17.20/TD_FTest.json +166 -83
- teradataml/data/jsons/sqle/17.20/TD_GLM.json +90 -14
- teradataml/data/jsons/sqle/17.20/TD_GLMPREDICT.json +48 -5
- teradataml/data/jsons/sqle/17.20/TD_GetFutileColumns.json +5 -3
- teradataml/data/jsons/sqle/17.20/TD_KMeans.json +31 -11
- teradataml/data/jsons/sqle/17.20/TD_NonLinearCombineFit.json +3 -2
- teradataml/data/jsons/sqle/17.20/TD_ROC.json +2 -1
- teradataml/data/jsons/sqle/17.20/TD_SVM.json +16 -16
- teradataml/data/jsons/sqle/17.20/TD_SVMPredict.json +19 -1
- teradataml/data/jsons/sqle/17.20/TD_ScaleFit.json +168 -15
- teradataml/data/jsons/sqle/17.20/TD_ScaleTransform.json +50 -1
- teradataml/data/jsons/sqle/17.20/TD_XGBoost.json +25 -7
- teradataml/data/jsons/sqle/17.20/TD_XGBoostPredict.json +17 -4
- teradataml/data/jsons/sqle/17.20/TD_ZTest.json +157 -80
- teradataml/data/kmeans_example.json +5 -0
- teradataml/data/kmeans_table.csv +10 -0
- teradataml/data/onehot_encoder_train.csv +4 -0
- teradataml/data/openml_example.json +29 -0
- teradataml/data/scale_attributes.csv +3 -0
- teradataml/data/scale_example.json +52 -1
- teradataml/data/scale_input_part_sparse.csv +31 -0
- teradataml/data/scale_input_partitioned.csv +16 -0
- teradataml/data/scale_input_sparse.csv +11 -0
- teradataml/data/scale_parameters.csv +3 -0
- teradataml/data/scripts/deploy_script.py +20 -1
- teradataml/data/scripts/sklearn/sklearn_fit.py +23 -27
- teradataml/data/scripts/sklearn/sklearn_fit_predict.py +20 -28
- teradataml/data/scripts/sklearn/sklearn_function.template +13 -18
- teradataml/data/scripts/sklearn/sklearn_model_selection_split.py +23 -33
- teradataml/data/scripts/sklearn/sklearn_neighbors.py +18 -27
- teradataml/data/scripts/sklearn/sklearn_score.py +20 -29
- teradataml/data/scripts/sklearn/sklearn_transform.py +30 -38
- teradataml/data/teradataml_example.json +77 -0
- teradataml/data/ztest_example.json +16 -0
- teradataml/dataframe/copy_to.py +8 -3
- teradataml/dataframe/data_transfer.py +120 -61
- teradataml/dataframe/dataframe.py +102 -17
- teradataml/dataframe/dataframe_utils.py +47 -9
- teradataml/dataframe/fastload.py +272 -89
- teradataml/dataframe/sql.py +84 -0
- teradataml/dbutils/dbutils.py +2 -2
- teradataml/lib/aed_0_1.dll +0 -0
- teradataml/opensource/sklearn/_sklearn_wrapper.py +102 -55
- teradataml/options/__init__.py +13 -4
- teradataml/options/configure.py +27 -6
- teradataml/scriptmgmt/UserEnv.py +19 -16
- teradataml/scriptmgmt/lls_utils.py +117 -14
- teradataml/table_operators/Script.py +2 -3
- teradataml/table_operators/TableOperator.py +58 -10
- teradataml/utils/validators.py +40 -2
- {teradataml-20.0.0.0.dist-info → teradataml-20.0.0.1.dist-info}/METADATA +78 -6
- {teradataml-20.0.0.0.dist-info → teradataml-20.0.0.1.dist-info}/RECORD +108 -90
- {teradataml-20.0.0.0.dist-info → teradataml-20.0.0.1.dist-info}/WHEEL +0 -0
- {teradataml-20.0.0.0.dist-info → teradataml-20.0.0.1.dist-info}/top_level.txt +0 -0
- {teradataml-20.0.0.0.dist-info → teradataml-20.0.0.1.dist-info}/zip-safe +0 -0
teradataml/LICENSE-3RD-PARTY.pdf
CHANGED
Binary file

teradataml/LICENSE.pdf
CHANGED
Binary file
teradataml/README.md
CHANGED

@@ -16,6 +16,77 @@ Copyright 2024, Teradata. All Rights Reserved.
 * [License](#license)
 
 ## Release Notes:
+#### teradataml 20.00.00.01
+* teradataml no longer supports Python versions less than 3.8.
+
+* ##### New Features/Functionality
+  * ##### Personal Access Token (PAT) support in teradataml
+    * `set_auth_token()` - teradataml now supports authentication via PAT in addition to
+      the OAuth 2.0 Device Authorization Grant (formerly known as the Device Flow).
+      * It accepts the UES URL, a Personal Access Token (PAT) and a Private Key file generated from the VantageCloud Lake Console,
+        and the optional arguments `username` and `expiration_time` in seconds.
+
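For illustration, a minimal sketch of PAT-based authentication as described in the note above. The keyword names `pat_token`, `pem_file`, `username` and `expiration_time` are assumptions inferred from that description, not a confirmed signature; check `help(set_auth_token)` in the installed package.

```python
# Hypothetical set_auth_token() call with a PAT; argument names are
# assumptions based on the release note above.
from teradataml import create_context, set_auth_token

create_context(host="<vantage-host>", username="<user>", password="<password>")

set_auth_token(ues_url="https://<lake-env>/open-analytics",  # UES URL from the Lake Console
               pat_token="<personal-access-token>",          # PAT generated in the Console
               pem_file="/path/to/private_key.pem",          # private key file from the Console
               username="<user>",                            # optional
               expiration_time=3600)                         # optional, in seconds
```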
+* ##### Updates
+  * ##### teradataml: SQLE Engine Analytic Functions
+    * `ANOVA()`
+      * New arguments added: `group_name_column`, `group_value_name`, `group_names`, `num_groups` for data containing group values and group names.
+    * `FTest()`
+      * New arguments added: `sample_name_column`, `sample_name_value`, `first_sample_name`, `second_sample_name`.
+    * `GLM()`
+      * Supports stepwise regression and accepts the new arguments `stepwise_direction`, `max_steps_num` and `initial_stepwise_columns`.
+      * New arguments added: `attribute_data`, `parameter_data`, `iteration_mode` and `partition_column`.
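For illustration, a hedged sketch of a stepwise `GLM()` fit using the arguments just listed. The table, columns and the `"forward"` value for `stepwise_direction` are assumptions, not confirmed API values.

```python
# Illustrative stepwise GLM fit; the data, columns and the "forward"
# direction value are assumptions for this sketch.
from teradataml import DataFrame, GLM

admissions = DataFrame("admissions_train")
fit = GLM(data=admissions,
          input_columns=["gpa", "stats", "programming"],
          response_column="admitted",
          family="BINOMIAL",
          stepwise_direction="forward",
          max_steps_num=5,
          initial_stepwise_columns=["gpa"])
```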
+    * `GetFutileColumns()`
+      * Arguments `category_summary_column` and `threshold_value` are now optional.
+    * `KMeans()`
+      * New argument added: `initialcentroids_method`.
+    * `NonLinearCombineFit()`
+      * Argument `result_column` is now optional.
+    * `ROC()`
+      * Argument `positive_class` is now optional.
+    * `SVMPredict()`
+      * New argument added: `model_type`.
+    * `ScaleFit()`
+      * New arguments added: `ignoreinvalid_locationscale`, `unused_attributes`, `attribute_name_column`, `attribute_value_column`.
+      * Arguments `attribute_name_column`, `attribute_value_column` and `target_attributes` are supported for sparse input.
+      * Arguments `attribute_data`, `parameter_data` and `partition_column` are supported for partitioning.
+    * `ScaleTransform()`
+      * New arguments added: `attribute_name_column` and `attribute_value_column` to support sparse input.
+    * `TDGLMPredict()`
+      * New arguments added: `family` and `partition_column`.
+    * `XGBoost()`
+      * New argument `base_score` is added to set the initial prediction value for all data points.
+    * `XGBoostPredict()`
+      * New argument `detailed` is added to report detailed information for each prediction.
+    * `ZTest()`
+      * New arguments added: `sample_name_column`, `sample_value_column`, `first_sample_name` and `second_sample_name`.
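For illustration, a sketch of `ScaleFit()` on sparse input using the new arguments named above; the table name, column names and argument combination are assumptions for the sketch.

```python
# Illustrative ScaleFit call on sparse input; table and column names are
# assumptions, not values confirmed by the release notes.
from teradataml import DataFrame, ScaleFit

sparse_input = DataFrame("scale_input_sparse")
fit = ScaleFit(data=sparse_input,
               attribute_name_column="attribute_name",
               attribute_value_column="attribute_value",
               target_attributes=["age", "income"],
               scale_method="STD")
```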
+  * ##### teradataml: AutoML
+    * `AutoML()`, `AutoRegressor()` and `AutoClassifier()`
+      * New argument `max_models` is added as an early stopping criterion to limit the maximum number of models to be trained.
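For illustration, a minimal sketch of the `max_models` stopping criterion; the table name and target column are assumptions.

```python
# Illustrative AutoClassifier run capped at 10 trained models; the table
# and target column name are assumptions for this sketch.
from teradataml import DataFrame, AutoClassifier

train = DataFrame("bank_marketing")
aml = AutoClassifier(max_models=10)  # stop once 10 models have been trained
aml.fit(train, "y")                  # target column "y" is an assumption
```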
+  * ##### teradataml: DataFrame functions
+    * `DataFrame.agg()`
+      * Accepts ColumnExpressions and lists of ColumnExpressions as arguments.
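For illustration, a sketch of `DataFrame.agg()` with ColumnExpression arguments; the table and column names are assumptions.

```python
# Illustrative agg() calls with ColumnExpressions; the table and column
# names are assumptions for this sketch.
from teradataml import DataFrame

df = DataFrame("sales")

df.agg(df.revenue.mean())                   # a single ColumnExpression
df.agg([df.revenue.sum(), df.units.max()])  # a list of ColumnExpressions
```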
+  * ##### teradataml: General Functions
+    * Data Transfer Utility
+      * `fastload()` - Improved error and warning table handling with the new arguments listed below.
+        * `err_staging_db`
+        * `err_tbl_name`
+        * `warn_tbl_name`
+        * `err_tbl_1_suffix`
+        * `err_tbl_2_suffix`
+      * `fastload()` - Changed behaviour of the `save_errors` argument.
+        When `save_errors` is set to `True`, error information is made available in two persistent tables, `ERR_1` and `ERR_2`.
+        When `save_errors` is set to `False`, error information is made available in a single pandas DataFrame.
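For illustration, a sketch of `fastload()` using the new error/warning table arguments listed above; the pandas source and all table/database names are assumptions.

```python
# Illustrative fastload() call; the argument names come from the release
# notes above, while all values are assumptions for this sketch.
import pandas as pd
from teradataml import fastload

pdf = pd.read_csv("my_data.csv")
fastload(df=pdf,
         table_name="my_table",
         save_errors=True,              # keep errors in persistent ERR_1/ERR_2 tables
         err_staging_db="stage_db",     # database holding the staged error tables
         err_tbl_name="my_table_err",
         warn_tbl_name="my_table_warn",
         err_tbl_1_suffix="_e1",
         err_tbl_2_suffix="_e2")
```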
+      * Garbage collector location is now configurable.
+        Users can set `configure.local_storage` to a desired location.
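For illustration, a minimal sketch of redirecting the garbage collector's local storage; the path is illustrative.

```python
# Point teradataml's garbage collector at a custom local directory.
from teradataml import configure

configure.local_storage = "/tmp/teradataml_gc"
```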
+
+* ##### Bug Fixes
+  * UAF functions now work if the database name contains special characters.
+  * OpensourceML can now read and process NULL/nan values.
+  * Boolean output values are now returned as a VARBYTE column with 0 or 1 values in OpensourceML.
+  * Fixed a bug in `Apply`'s `deploy()`.
+  * Fixed an issue with volatile table creation: the table is now created in the correct database, i.e., the user's spool space, regardless of the temp database specified.
+  * The `ColumnTransformer` function now processes its arguments in the order they are passed.
+
 #### teradataml 20.00.00.00
 * ##### New Features/Functionality
   * ###### teradataml OpenML: Run Opensource packages through Teradata Vantage
teradataml/_version.py
CHANGED

@@ -1,6 +1,6 @@
 # ##################################################################
 #
-# Copyright
+# Copyright 2024 Teradata. All rights reserved.
 # TERADATA CONFIDENTIAL AND TRADE SECRET
 #
 # Primary Owner: Pankaj Purandare (PankajVinod.Purandare@teradata.com)
@@ -8,4 +8,4 @@
 #
 # ##################################################################
 
-version = "20.00.00.
+version = "20.00.00.01"
teradataml/analytics/analytic_function_executor.py
CHANGED

@@ -86,6 +86,9 @@ class _AnlyticFunctionExecutor:
         # Initialize FuncSpecialCaseHandler.
         self._spl_func_obj = FuncSpecialCaseHandler(self.func_name)
 
+        # Initialize database object type.
+        self.db_object_type = TeradataConstants.TERADATA_VIEW
+
     @staticmethod
     def _validate_analytic_function_argument(func_arg_name, func_arg_value, argument, additional_valid_types=None):
         """
@@ -178,7 +181,7 @@ class _AnlyticFunctionExecutor:
 
         EXAMPLES:
             self._execute_query()
-        """
+        """
         # Generate STDOUT table name and add it to the output table list.
         func_params = self._get_generate_temp_table_params(persist=persist, volatile=volatile)
         sqlmr_stdout_temp_tablename = UtilFuncs._generate_temp_table_name(**func_params)
@@ -248,25 +251,18 @@ class _AnlyticFunctionExecutor:
             self._get_generate_temp_table_params(True, True)
         """
         use_default_database = True
-        db_object_type = TeradataConstants.TERADATA_VIEW
        prefix = "td_sqlmr_out_"
         gc_on_quit = True
 
-        # If
-        #
-        # reading from a view created on output, then 'db_object_type' should be "table".
-        if len(self._metadata.output_tables) > 0 or not self._metadata._is_view_supported:
-            db_object_type = TeradataConstants.TERADATA_TABLE
-
-        # If result is to be persisted or if the table is a volaile table then, db_object_type
-        # should be "table" and it must not be Garbage collected.
+        # If result is to be persisted or if the table is a volaile table then,
+        # it must not be Garbage collected.
         if persist or volatile:
             gc_on_quit = False
-            db_object_type = TeradataConstants.TERADATA_TABLE
             prefix = "td_sqlmr_{}_out_".format("persist" if persist else "volatile")
+            use_default_database = False if volatile else True
 
         return {"use_default_database": use_default_database,
-                "table_type": db_object_type,
+                "table_type": self.db_object_type,
                 "prefix": prefix,
                 "gc_on_quit": gc_on_quit}
@@ -694,10 +690,26 @@ class _AnlyticFunctionExecutor:
                                       MessageCodes.CANNOT_USE_TOGETHER_WITH)
 
         self._dyn_cls_data_members.update(kwargs)
-
+
+        # If function produces output tables, i.e., function has output table arguments,
+        # then 'db_object_type' should be "table" or if analytic function does not support
+        # reading from a view created on output, then 'db_object_type' should be "table".
+        # If result is to be persisted or if the table is a volaile table then, db_object_type
+        # should be "table" else it should be "view".
+        self.db_object_type = (
+            TeradataConstants.TERADATA_VOLATILE_TABLE if volatile
+            else TeradataConstants.TERADATA_TABLE if len(self._metadata.output_tables) > 0 \
+                or not self._metadata._is_view_supported or persist
+            else TeradataConstants.TERADATA_VIEW
+        )
         if not skip_input_arg_processing:
             self._process_input_argument(**kwargs)
+
+        # check func_name is GLM and data_partition_column, data_hash_column, local_order_data are passed
+        if self.func_name in ['GLM', 'TDGLMPredict'] and \
+                any(key in kwargs for key in ['data_partition_column', 'data_hash_column', 'local_order_data']):
+            skip_output_arg_processing = True
+
         if not skip_output_arg_processing:
             self._process_output_argument(**kwargs)
@@ -856,22 +868,34 @@ class _SQLEFunctionExecutor(_AnlyticFunctionExecutor):
         EXAMPLES:
             self._get_input_args()
         """
+        sort_order = list(kwargs.keys())
+        input_table_dict = {}
+
         for _inp_attribute in self._metadata.input_tables:
             input_table_arg = _inp_attribute.get_lang_name()
-            yield input_table_arg, _inp_attribute
 
-            #
+            # Store the first argument directly into the dictionary
+            input_table_dict[input_table_arg] = _inp_attribute
+
+            # Check if SQL function allows multiple values as input.
             if _inp_attribute.allows_lists():
                 _index = 1
                 while True:
                     _input_table_arg = "{}{}".format(input_table_arg, _index)
-                    # If the corresponding object is available in kwargs, then extract it.
-                    # Otherwise, stop looking for multiple arguments and proceed for next attribute.
                     if _input_table_arg in kwargs:
-
-                        _index
+                        input_table_dict[_input_table_arg] = _inp_attribute
+                        _index += 1
                     else:
                         break
+
+        # For ColumnTransformer, yield the input arguments in the order they are passed.
+        if self.func_name == "ColumnTransformer":
+            for key in sort_order:
+                if key in input_table_dict:
+                    yield key, input_table_dict[key]
+        else:
+            for key in input_table_dict:
+                yield key, input_table_dict[key]
@@ -1707,14 +1731,17 @@ class _UAFFunctionExecutor(_SQLEFunctionExecutor):
             self._get_generate_temp_table_params(True, True)
         """
         prefix = "td_uaf_out_"
-
+        gc_on_quit = True
         # If result is to be persisted then, it must not be Garbage collected.
-
+        if persist or volatile:
+            gc_on_quit = False
+            prefix = "td_uaf_{}_out_".format("persist" if persist else "volatile")
 
-        return {"table_type":
+        return {"table_type": self.db_object_type,
                 "prefix": prefix,
                 "gc_on_quit": gc_on_quit,
-                "databasename": output_db if output_db else _get_context_temp_databasename(
+                "databasename": output_db if output_db else _get_context_temp_databasename(
+                    table_type=self.db_object_type)}
@@ -1762,7 +1789,7 @@ class _UAFFunctionExecutor(_SQLEFunctionExecutor):
         # If database name is not provided by user, get the default database name
         # else use user provided database name.
         db_name = output_db_name if output_db_name is not None else \
-            _get_context_temp_databasename()
+            _get_context_temp_databasename(table_type=self.db_object_type)
 
         # Get the fully qualified table name.
         table_name = "{}.{}".format(UtilFuncs._teradata_quote_arg(db_name,
@@ -608,12 +608,16 @@ class _Evaluate:
         if self.get_function_name() == "NaiveBayesTextClassifierTrainer":
             return True
         # name of argument is model_type for most of the functions but for some it is different
-        if "model_type" not in kwargs:
+        if "model_type" not in kwargs and "tree_type" not in kwargs:
             arg_name = self.get_arg_name()
             model_type = getattr(self.obj, arg_name)
-
+            if self.get_function_name() == "DecisionForest":
+                kwargs["tree_type"] = model_type
+            else:
+                kwargs["model_type"] = model_type
 
-        if kwargs["model_type"].lower() == "binomial" or kwargs["model_type"].lower() == "classification"
+        if ("model_type" in kwargs and (kwargs["model_type"].lower() == "binomial" or kwargs["model_type"].lower() == "classification")) \
+                or ("tree_type" in kwargs and kwargs["tree_type"].lower() == "classification"):
             is_classification_model = True
 
         return is_classification_model
@@ -720,20 +724,10 @@ class _Evaluate:
         kwargs["observation_column"] = response_column
         kwargs["prediction_column"] = "Prediction" if "Prediction" in predict.result.columns else "prediction"
 
-        #
-        #
-
-
-        pre_col_name = kwargs["prediction_column"]
-        if res[kwargs["observation_column"]] != res[pre_col_name]:
-            # Converting the prediction column datatype to observation column datatype.
-            cast_cols_pre = {pre_col_name: getattr(predict.result, pre_col_name).expression.cast(
-                type_=res[kwargs["observation_column"]])}
-            # Update the predicted result dataframe.
-            predict.result = predict.result.assign(**cast_cols_pre)
-
-        # Update the num_labels by the number of unique values.
-        kwargs["num_labels"] = predict.result.drop_duplicate(kwargs["observation_column"]).shape[0]
+        # Update the num_labels by the number of unique values if
+        # Labels are not passed.
+        if "labels" not in kwargs:
+            kwargs["num_labels"] = predict.result.drop_duplicate(kwargs["observation_column"]).shape[0]
 
         kwargs["data"] = predict.result