maxframe 0.1.0b5__cp37-cp37m-win_amd64.whl → 1.0.0__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. More details are linked from the original diff page.

Files changed (203)
  1. maxframe/_utils.cp37-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +10 -4
  3. maxframe/config/config.py +68 -10
  4. maxframe/config/validators.py +42 -11
  5. maxframe/conftest.py +58 -14
  6. maxframe/core/__init__.py +2 -16
  7. maxframe/core/entity/__init__.py +1 -12
  8. maxframe/core/entity/executable.py +1 -1
  9. maxframe/core/entity/objects.py +46 -45
  10. maxframe/core/entity/output_types.py +0 -3
  11. maxframe/core/entity/tests/test_objects.py +43 -0
  12. maxframe/core/entity/tileables.py +5 -78
  13. maxframe/core/graph/__init__.py +2 -2
  14. maxframe/core/graph/builder/__init__.py +0 -1
  15. maxframe/core/graph/builder/base.py +5 -4
  16. maxframe/core/graph/builder/tileable.py +4 -4
  17. maxframe/core/graph/builder/utils.py +4 -8
  18. maxframe/core/graph/core.cp37-win_amd64.pyd +0 -0
  19. maxframe/core/graph/core.pyx +4 -4
  20. maxframe/core/graph/entity.py +9 -33
  21. maxframe/core/operator/__init__.py +2 -9
  22. maxframe/core/operator/base.py +3 -5
  23. maxframe/core/operator/objects.py +0 -9
  24. maxframe/core/operator/utils.py +55 -0
  25. maxframe/dataframe/__init__.py +1 -1
  26. maxframe/dataframe/arithmetic/around.py +5 -17
  27. maxframe/dataframe/arithmetic/core.py +15 -7
  28. maxframe/dataframe/arithmetic/docstring.py +7 -33
  29. maxframe/dataframe/arithmetic/equal.py +4 -2
  30. maxframe/dataframe/arithmetic/greater.py +4 -2
  31. maxframe/dataframe/arithmetic/greater_equal.py +4 -2
  32. maxframe/dataframe/arithmetic/less.py +2 -2
  33. maxframe/dataframe/arithmetic/less_equal.py +4 -2
  34. maxframe/dataframe/arithmetic/not_equal.py +4 -2
  35. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +39 -16
  36. maxframe/dataframe/core.py +31 -7
  37. maxframe/dataframe/datasource/date_range.py +2 -2
  38. maxframe/dataframe/datasource/read_odps_query.py +117 -23
  39. maxframe/dataframe/datasource/read_odps_table.py +6 -3
  40. maxframe/dataframe/datasource/tests/test_datasource.py +103 -8
  41. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  42. maxframe/dataframe/datastore/to_odps.py +28 -0
  43. maxframe/dataframe/extensions/__init__.py +5 -0
  44. maxframe/dataframe/extensions/flatjson.py +131 -0
  45. maxframe/dataframe/extensions/flatmap.py +317 -0
  46. maxframe/dataframe/extensions/reshuffle.py +1 -1
  47. maxframe/dataframe/extensions/tests/test_extensions.py +108 -3
  48. maxframe/dataframe/groupby/core.py +1 -1
  49. maxframe/dataframe/groupby/cum.py +0 -1
  50. maxframe/dataframe/groupby/fill.py +4 -1
  51. maxframe/dataframe/groupby/getitem.py +6 -0
  52. maxframe/dataframe/groupby/tests/test_groupby.py +5 -1
  53. maxframe/dataframe/groupby/transform.py +5 -1
  54. maxframe/dataframe/indexing/align.py +1 -1
  55. maxframe/dataframe/indexing/loc.py +6 -4
  56. maxframe/dataframe/indexing/rename.py +5 -28
  57. maxframe/dataframe/indexing/sample.py +0 -1
  58. maxframe/dataframe/indexing/set_index.py +68 -1
  59. maxframe/dataframe/initializer.py +11 -1
  60. maxframe/dataframe/merge/__init__.py +9 -1
  61. maxframe/dataframe/merge/concat.py +41 -31
  62. maxframe/dataframe/merge/merge.py +237 -3
  63. maxframe/dataframe/merge/tests/test_merge.py +126 -1
  64. maxframe/dataframe/misc/apply.py +5 -10
  65. maxframe/dataframe/misc/case_when.py +1 -1
  66. maxframe/dataframe/misc/describe.py +2 -2
  67. maxframe/dataframe/misc/drop_duplicates.py +8 -8
  68. maxframe/dataframe/misc/eval.py +4 -0
  69. maxframe/dataframe/misc/memory_usage.py +2 -2
  70. maxframe/dataframe/misc/pct_change.py +1 -83
  71. maxframe/dataframe/misc/tests/test_misc.py +33 -2
  72. maxframe/dataframe/misc/transform.py +1 -30
  73. maxframe/dataframe/misc/value_counts.py +4 -17
  74. maxframe/dataframe/missing/dropna.py +1 -1
  75. maxframe/dataframe/missing/fillna.py +5 -5
  76. maxframe/dataframe/operators.py +1 -17
  77. maxframe/dataframe/reduction/core.py +2 -2
  78. maxframe/dataframe/reduction/tests/test_reduction.py +2 -4
  79. maxframe/dataframe/sort/sort_values.py +1 -11
  80. maxframe/dataframe/statistics/corr.py +3 -3
  81. maxframe/dataframe/statistics/quantile.py +13 -19
  82. maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
  83. maxframe/dataframe/tests/test_initializer.py +33 -2
  84. maxframe/dataframe/utils.py +26 -11
  85. maxframe/dataframe/window/expanding.py +5 -3
  86. maxframe/dataframe/window/tests/test_expanding.py +2 -2
  87. maxframe/errors.py +13 -0
  88. maxframe/extension.py +12 -0
  89. maxframe/io/__init__.py +13 -0
  90. maxframe/io/objects/__init__.py +24 -0
  91. maxframe/io/objects/core.py +140 -0
  92. maxframe/io/objects/tensor.py +76 -0
  93. maxframe/io/objects/tests/__init__.py +13 -0
  94. maxframe/io/objects/tests/test_object_io.py +97 -0
  95. maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
  96. maxframe/{odpsio → io/odpsio}/arrow.py +42 -10
  97. maxframe/{odpsio → io/odpsio}/schema.py +38 -16
  98. maxframe/io/odpsio/tableio.py +719 -0
  99. maxframe/io/odpsio/tests/__init__.py +13 -0
  100. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +59 -22
  101. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
  102. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  103. maxframe/io/odpsio/volumeio.py +63 -0
  104. maxframe/learn/contrib/__init__.py +3 -1
  105. maxframe/learn/contrib/graph/__init__.py +15 -0
  106. maxframe/learn/contrib/graph/connected_components.py +215 -0
  107. maxframe/learn/contrib/graph/tests/__init__.py +13 -0
  108. maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
  109. maxframe/learn/contrib/llm/__init__.py +16 -0
  110. maxframe/learn/contrib/llm/core.py +54 -0
  111. maxframe/learn/contrib/llm/models/__init__.py +14 -0
  112. maxframe/learn/contrib/llm/models/dashscope.py +73 -0
  113. maxframe/learn/contrib/llm/multi_modal.py +42 -0
  114. maxframe/learn/contrib/llm/text.py +42 -0
  115. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  116. maxframe/learn/contrib/xgboost/core.py +87 -2
  117. maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
  118. maxframe/learn/contrib/xgboost/predict.py +29 -46
  119. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  120. maxframe/learn/contrib/xgboost/train.py +29 -18
  121. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  122. maxframe/lib/mmh3.cp37-win_amd64.pyd +0 -0
  123. maxframe/lib/mmh3.pyi +43 -0
  124. maxframe/lib/sparse/tests/test_sparse.py +15 -15
  125. maxframe/lib/wrapped_pickle.py +2 -1
  126. maxframe/opcodes.py +8 -0
  127. maxframe/protocol.py +154 -27
  128. maxframe/remote/core.py +4 -8
  129. maxframe/serialization/__init__.py +1 -0
  130. maxframe/serialization/core.cp37-win_amd64.pyd +0 -0
  131. maxframe/serialization/core.pxd +3 -0
  132. maxframe/serialization/core.pyi +3 -0
  133. maxframe/serialization/core.pyx +67 -26
  134. maxframe/serialization/exception.py +1 -1
  135. maxframe/serialization/pandas.py +52 -17
  136. maxframe/serialization/serializables/core.py +180 -15
  137. maxframe/serialization/serializables/field_type.py +4 -1
  138. maxframe/serialization/serializables/tests/test_serializable.py +54 -5
  139. maxframe/serialization/tests/test_serial.py +2 -1
  140. maxframe/session.py +9 -2
  141. maxframe/tensor/__init__.py +81 -2
  142. maxframe/tensor/arithmetic/isclose.py +1 -0
  143. maxframe/tensor/arithmetic/tests/test_arithmetic.py +22 -18
  144. maxframe/tensor/core.py +5 -136
  145. maxframe/tensor/datasource/array.py +3 -0
  146. maxframe/tensor/datasource/full.py +1 -1
  147. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  148. maxframe/tensor/indexing/flatnonzero.py +1 -1
  149. maxframe/tensor/indexing/getitem.py +2 -0
  150. maxframe/tensor/merge/__init__.py +2 -0
  151. maxframe/tensor/merge/concatenate.py +101 -0
  152. maxframe/tensor/merge/tests/test_merge.py +30 -1
  153. maxframe/tensor/merge/vstack.py +74 -0
  154. maxframe/tensor/{base → misc}/__init__.py +2 -0
  155. maxframe/tensor/{base → misc}/atleast_1d.py +1 -3
  156. maxframe/tensor/misc/atleast_2d.py +70 -0
  157. maxframe/tensor/misc/atleast_3d.py +85 -0
  158. maxframe/tensor/misc/tests/__init__.py +13 -0
  159. maxframe/tensor/{base → misc}/transpose.py +22 -18
  160. maxframe/tensor/{base → misc}/unique.py +3 -3
  161. maxframe/tensor/operators.py +1 -7
  162. maxframe/tensor/random/core.py +1 -1
  163. maxframe/tensor/reduction/count_nonzero.py +2 -1
  164. maxframe/tensor/reduction/mean.py +1 -0
  165. maxframe/tensor/reduction/nanmean.py +1 -0
  166. maxframe/tensor/reduction/nanvar.py +2 -0
  167. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  168. maxframe/tensor/reduction/var.py +2 -0
  169. maxframe/tensor/statistics/quantile.py +2 -2
  170. maxframe/tensor/utils.py +2 -22
  171. maxframe/tests/test_protocol.py +34 -0
  172. maxframe/tests/test_utils.py +0 -12
  173. maxframe/tests/utils.py +17 -2
  174. maxframe/typing_.py +4 -1
  175. maxframe/udf.py +8 -9
  176. maxframe/utils.py +106 -86
  177. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/METADATA +3 -3
  178. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/RECORD +197 -173
  179. maxframe_client/__init__.py +0 -1
  180. maxframe_client/clients/framedriver.py +4 -1
  181. maxframe_client/fetcher.py +81 -74
  182. maxframe_client/session/consts.py +3 -0
  183. maxframe_client/session/graph.py +8 -2
  184. maxframe_client/session/odps.py +194 -40
  185. maxframe_client/session/task.py +94 -39
  186. maxframe_client/tests/test_fetcher.py +21 -3
  187. maxframe_client/tests/test_session.py +109 -8
  188. maxframe/core/entity/chunks.py +0 -68
  189. maxframe/core/entity/fuse.py +0 -73
  190. maxframe/core/graph/builder/chunk.py +0 -430
  191. maxframe/odpsio/tableio.py +0 -322
  192. maxframe/odpsio/volumeio.py +0 -95
  193. maxframe_client/clients/spe.py +0 -104
  194. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  195. /maxframe/{tensor/base → dataframe/datastore}/tests/__init__.py +0 -0
  196. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  197. /maxframe/tensor/{base → misc}/astype.py +0 -0
  198. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  199. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  200. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  201. /maxframe/tensor/{base → misc}/where.py +0 -0
  202. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/WHEEL +0 -0
  203. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,53 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import pytest
17
+
18
+ from ..... import dataframe as md
19
+ from .....dataframe.core import DataFrameData
20
+ from .....tensor.core import TensorData
21
+ from .. import connected_components
22
+
23
+
24
+ @pytest.fixture
25
+ def df1():
26
+ return md.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
27
+
28
+
29
+ @pytest.fixture
30
+ def df2():
31
+ return md.DataFrame(
32
+ [[1, "2"], [1, "2"]],
33
+ columns=["a", "b"],
34
+ )
35
+
36
+
37
+ def test_connected_components(df1, df2):
38
+ edges, flag = connected_components(df1, "a", "b")
39
+ assert edges.op.max_iter == 6
40
+ assert edges.shape == (np.nan, 2)
41
+ assert isinstance(edges.data, DataFrameData)
42
+ assert isinstance(flag.data, TensorData)
43
+ assert flag.shape == ()
44
+ assert "id" in edges.dtypes and "component" in edges.dtypes
45
+
46
+ with pytest.raises(ValueError):
47
+ connected_components(df1, "a", "x")
48
+
49
+ with pytest.raises(ValueError):
50
+ connected_components(df1, "a", "b", 0)
51
+
52
+ with pytest.raises(ValueError):
53
+ connected_components(df2, "a", "b")
@@ -0,0 +1,16 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from . import models, multi_modal, text
15
+
16
+ del models
@@ -0,0 +1,54 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from ....core.entity.output_types import OutputType
20
+ from ....core.operator.base import Operator
21
+ from ....core.operator.core import TileableOperatorMixin
22
+ from ....dataframe.utils import parse_index
23
+ from ....serialization.serializables.core import Serializable
24
+ from ....serialization.serializables.field import AnyField, DictField, StringField
25
+
26
+
27
+ class LLM(Serializable):
28
+ name = StringField("name", default=None)
29
+
30
+ def validate_params(self, params: Dict[str, Any]):
31
+ pass
32
+
33
+
34
+ class LLMOperator(Operator, TileableOperatorMixin):
35
+ model = AnyField("model", default=None)
36
+ prompt_template = AnyField("prompt_template", default=None)
37
+ params = DictField("params", default=None)
38
+
39
+ def __init__(self, output_types=None, **kw):
40
+ if output_types is None:
41
+ output_types = [OutputType.dataframe]
42
+ super().__init__(_output_types=output_types, **kw)
43
+
44
+ def __call__(self, data):
45
+ col_names = ["response", "success"]
46
+ columns = parse_index(pd.Index(col_names), store_data=True)
47
+ out_dtypes = pd.Series([np.dtype("O"), np.dtype("bool")], index=col_names)
48
+ return self.new_tileable(
49
+ inputs=[data],
50
+ dtypes=out_dtypes,
51
+ shape=(data.shape[0], len(col_names)),
52
+ index_value=data.index_value,
53
+ columns_value=columns,
54
+ )
@@ -0,0 +1,14 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .dashscope import DashScopeMultiModalLLM, DashScopeTextLLM
@@ -0,0 +1,73 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ..... import opcodes
17
+ from .....serialization.serializables.core import Serializable
18
+ from .....serialization.serializables.field import StringField
19
+ from ..core import LLMOperator
20
+ from ..multi_modal import MultiModalLLM
21
+ from ..text import TextLLM
22
+
23
+
24
+ class DashScopeLLMMixin(Serializable):
25
+ __slots__ = ()
26
+
27
+ _not_supported_params = {"stream", "incremental_output"}
28
+
29
+ def validate_params(self, params: Dict[str, Any]):
30
+ for k in params.keys():
31
+ if k in self._not_supported_params:
32
+ raise ValueError(f"{k} is not supported")
33
+
34
+
35
+ class DashScopeTextLLM(TextLLM, DashScopeLLMMixin):
36
+ api_key_resource = StringField("api_key_resource", default=None)
37
+
38
+ def generate(
39
+ self,
40
+ data,
41
+ prompt_template: Dict[str, Any],
42
+ params: Dict[str, Any] = None,
43
+ ):
44
+ return DashScopeTextGenerationOperator(
45
+ model=self,
46
+ prompt_template=prompt_template,
47
+ params=params,
48
+ )(data)
49
+
50
+
51
+ class DashScopeMultiModalLLM(MultiModalLLM, DashScopeLLMMixin):
52
+ api_key_resource = StringField("api_key_resource", default=None)
53
+
54
+ def generate(
55
+ self,
56
+ data,
57
+ prompt_template: Dict[str, Any],
58
+ params: Dict[str, Any] = None,
59
+ ):
60
+ # TODO add precheck here
61
+ return DashScopeMultiModalGenerationOperator(
62
+ model=self,
63
+ prompt_template=prompt_template,
64
+ params=params,
65
+ )(data)
66
+
67
+
68
+ class DashScopeTextGenerationOperator(LLMOperator):
69
+ _op_type_ = opcodes.DASHSCOPE_TEXT_GENERATION
70
+
71
+
72
+ class DashScopeMultiModalGenerationOperator(LLMOperator):
73
+ _op_type_ = opcodes.DASHSCOPE_MULTI_MODAL_GENERATION
@@ -0,0 +1,42 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
17
+ from .core import LLM
18
+
19
+
20
+ class MultiModalLLM(LLM):
21
+ def generate(
22
+ self,
23
+ data,
24
+ prompt_template: Dict[str, Any],
25
+ params: Dict[str, Any] = None,
26
+ ):
27
+ raise NotImplementedError
28
+
29
+
30
+ def generate(
31
+ data,
32
+ model: MultiModalLLM,
33
+ prompt_template: Dict[str, Any],
34
+ params: Dict[str, Any] = None,
35
+ ):
36
+ if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
37
+ raise ValueError("data must be a maxframe dataframe or series object")
38
+ if not isinstance(model, MultiModalLLM):
39
+ raise ValueError("model must be a MultiModalLLM object")
40
+ params = params if params is not None else dict()
41
+ model.validate_params(params)
42
+ return model.generate(data, prompt_template, params)
@@ -0,0 +1,42 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict
15
+
16
+ from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
17
+ from .core import LLM
18
+
19
+
20
+ class TextLLM(LLM):
21
+ def generate(
22
+ self,
23
+ data,
24
+ prompt_template: Dict[str, Any],
25
+ params: Dict[str, Any] = None,
26
+ ):
27
+ raise NotImplementedError
28
+
29
+
30
+ def generate(
31
+ data,
32
+ model: TextLLM,
33
+ prompt_template: Dict[str, Any],
34
+ params: Dict[str, Any] = None,
35
+ ):
36
+ if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
37
+ raise ValueError("data must be a maxframe dataframe or series object")
38
+ if not isinstance(model, TextLLM):
39
+ raise ValueError("model must be a TextLLM object")
40
+ params = params if params is not None else dict()
41
+ model.validate_params(params)
42
+ return model.generate(data, prompt_template, params)
@@ -14,7 +14,8 @@
14
14
 
15
15
  import numpy as np
16
16
 
17
- from ....tensor import argmax
17
+ from ....tensor import argmax, transpose
18
+ from ....tensor.merge.vstack import _vstack
18
19
  from ..utils import make_import_error_func
19
20
  from .core import XGBScikitLearnBase, xgboost
20
21
 
@@ -42,7 +43,10 @@ else:
42
43
  sample_weight_eval_set=None,
43
44
  base_margin_eval_set=None,
44
45
  num_class=None,
46
+ **kw,
45
47
  ):
48
+ session = kw.pop("session", None)
49
+ run_kwargs = kw.pop("run_kwargs", dict())
46
50
  dtrain, evals = wrap_evaluation_matrices(
47
51
  None,
48
52
  X,
@@ -68,6 +72,8 @@ else:
68
72
  evals=evals,
69
73
  evals_result=self.evals_result_,
70
74
  num_class=num_class,
75
+ session=session,
76
+ run_kwargs=run_kwargs,
71
77
  )
72
78
  self._Booster = result
73
79
  return self
@@ -83,4 +89,22 @@ else:
83
89
  def predict_proba(self, data, ntree_limit=None, flag=False, **kw):
84
90
  if ntree_limit is not None:
85
91
  raise NotImplementedError("ntree_limit is not currently supported")
86
- return predict(self.get_booster(), data, flag=flag, **kw)
92
+ prediction = predict(self.get_booster(), data, flag=flag, **kw)
93
+ if len(prediction.shape) == 2 and prediction.shape[1] == self.n_classes_:
94
+ # multi-class
95
+ return prediction
96
+ if (
97
+ len(prediction.shape) == 2
98
+ and self.n_classes_ == 2
99
+ and prediction.shape[1] >= self.n_classes_
100
+ ):
101
+ # multi-label
102
+ return prediction
103
+ # binary logistic function
104
+ classone_probs = prediction
105
+ classzero_probs = 1.0 - classone_probs
106
+ return transpose(_vstack((classzero_probs, classone_probs)))
107
+
108
+ @property
109
+ def classes_(self) -> np.ndarray:
110
+ return np.arange(self.n_classes_)
@@ -12,15 +12,67 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Callable, List, Optional, Tuple
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple
16
16
 
17
17
  try:
18
18
  import xgboost
19
19
  except ImportError:
20
20
  xgboost = None
21
21
 
22
+ from ...core import Model, ModelData
22
23
  from .dmatrix import DMatrix
23
24
 
25
+
26
+ class BoosterData(ModelData):
27
+ __slots__ = ("_evals_result",)
28
+
29
+ _evals_result: Dict
30
+
31
+ def __init__(self, *args, evals_result=None, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self._evals_result = evals_result if evals_result is not None else dict()
34
+
35
+ def execute(self, session=None, **kw):
36
+ # The evals_result should be fetched when BoosterData.execute() is called.
37
+ result = super().execute(session=session, **kw)
38
+ if self.op.has_evals_result and self.key == self.op.outputs[0].key:
39
+ self._evals_result.update(self.op.outputs[1].fetch(session=session))
40
+ return result
41
+
42
+ def predict(
43
+ self,
44
+ data,
45
+ output_margin=False,
46
+ pred_leaf=False,
47
+ pred_contribs=False,
48
+ approx_contribs=False,
49
+ pred_interactions=False,
50
+ validate_features=True,
51
+ training=False,
52
+ iteration_range=None,
53
+ strict_shape=False,
54
+ ):
55
+ from .predict import predict
56
+
57
+ return predict(
58
+ self,
59
+ data,
60
+ output_margin=output_margin,
61
+ pred_leaf=pred_leaf,
62
+ pred_contribs=pred_contribs,
63
+ approx_contribs=approx_contribs,
64
+ pred_interactions=pred_interactions,
65
+ validate_features=validate_features,
66
+ training=training,
67
+ iteration_range=iteration_range,
68
+ strict_shape=strict_shape,
69
+ )
70
+
71
+
72
+ class Booster(Model):
73
+ pass
74
+
75
+
24
76
  if not xgboost:
25
77
  XGBScikitLearnBase = None
26
78
  else:
@@ -40,7 +92,9 @@ else:
40
92
  **kw,
41
93
  ):
42
94
  """
43
- Fit the regressor.
95
+ Fit the regressor. Note that fit() is an eager-execution
96
+ API. The call will be blocked until training finished.
97
+
44
98
  Parameters
45
99
  ----------
46
100
  X : array_like
@@ -72,6 +126,37 @@ else:
72
126
  """
73
127
  raise NotImplementedError
74
128
 
129
+ def evals_result(self, **kw) -> Dict:
130
+ """Return the evaluation results.
131
+
132
+ If **eval_set** is passed to the :py:meth:`fit` function, you can call
133
+ ``evals_result()`` to get evaluation results for all passed **eval_sets**. When
134
+ **eval_metric** is also passed to the :py:meth:`fit` function, the
135
+ **evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit`
136
+ function.
137
+
138
+ The returned evaluation result is a dictionary:
139
+
140
+ .. code-block:: python
141
+
142
+ {'validation_0': {'logloss': ['0.604835', '0.531479']},
143
+ 'validation_1': {'logloss': ['0.41965', '0.17686']}}
144
+
145
+ Note that evals_result() will be blocked until the train is finished.
146
+
147
+ Returns
148
+ -------
149
+ evals_result
150
+
151
+ """
152
+ result = super().evals_result()
153
+ if not self._Booster.op.has_evals_result or len(result) != 0:
154
+ return result
155
+ session = kw.pop("session", None)
156
+ run_kwargs = kw.pop("run_kwargs", dict())
157
+ self._Booster.execute(session=session, **run_kwargs)
158
+ return super().evals_result()
159
+
75
160
  def wrap_evaluation_matrices(
76
161
  missing: float,
77
162
  X: Any,
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
- from .... import opcodes as OperandDef
16
+ from .... import opcodes
17
17
  from ....core.entity.output_types import get_output_types
18
18
  from ....core.operator.base import Operator
19
19
  from ....core.operator.core import TileableOperatorMixin
@@ -27,7 +27,7 @@ from ...utils import convert_to_tensor_or_dataframe
27
27
 
28
28
 
29
29
  class ToDMatrix(Operator, TileableOperatorMixin):
30
- _op_type_ = OperandDef.TO_DMATRIX
30
+ _op_type_ = opcodes.TO_DMATRIX
31
31
 
32
32
  data = KeyField("data", default=None)
33
33
  label = KeyField("label", default=None)
@@ -99,10 +99,7 @@ def check_array_like(y: TileableType, name: str) -> TileableType:
99
99
  y = convert_to_tensor_or_dataframe(y)
100
100
  if isinstance(y, DATAFRAME_TYPE):
101
101
  y = y.iloc[:, 0]
102
- y = astensor(y)
103
- if y.ndim != 1:
104
- raise ValueError(f"Expecting 1-d {name}, got: {y.ndim}-d")
105
- return y
102
+ return astensor(y)
106
103
 
107
104
 
108
105
  def to_dmatrix(
@@ -12,29 +12,30 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import pickle
16
15
 
17
16
  import numpy as np
18
- import pandas as pd
19
17
 
20
- from .... import opcodes as OperandDef
18
+ from .... import opcodes
21
19
  from ....core.entity.output_types import OutputType
22
20
  from ....core.operator.base import Operator
23
21
  from ....core.operator.core import TileableOperatorMixin
24
- from ....dataframe.utils import parse_index
25
- from ....serialization.serializables import BoolField, BytesField, KeyField, TupleField
26
- from ....tensor.core import TENSOR_TYPE, TensorOrder
22
+ from ....serialization.serializables import (
23
+ BoolField,
24
+ KeyField,
25
+ ReferenceField,
26
+ TupleField,
27
+ )
28
+ from ....tensor.core import TensorOrder
29
+ from .core import BoosterData
27
30
  from .dmatrix import check_data
28
31
 
29
32
 
30
33
  class XGBPredict(Operator, TileableOperatorMixin):
31
- _op_type_ = OperandDef.XGBOOST_PREDICT
34
+ _op_type_ = opcodes.XGBOOST_PREDICT
32
35
  output_dtype = np.dtype(np.float32)
33
36
 
34
37
  data = KeyField("data", default=None)
35
- model = BytesField(
36
- "model", on_serialize=pickle.dumps, on_deserialize=pickle.loads, default=None
37
- )
38
+ model = ReferenceField("model", reference_type=BoosterData, default=None)
38
39
  pred_leaf = BoolField("pred_leaf", default=False)
39
40
  pred_contribs = BoolField("pred_contribs", default=False)
40
41
  approx_contribs = BoolField("approx_contribs", default=False)
@@ -62,35 +63,12 @@ class XGBPredict(Operator, TileableOperatorMixin):
62
63
  else:
63
64
  shape = (self.data.shape[0],)
64
65
  inputs = [self.data, self.model]
65
- if self.output_types[0] == OutputType.tensor:
66
- # tensor
67
- return self.new_tileable(
68
- inputs,
69
- shape=shape,
70
- dtype=self.output_dtype,
71
- order=TensorOrder.C_ORDER,
72
- )
73
- elif self.output_types[0] == OutputType.dataframe:
74
- # dataframe
75
- dtypes = pd.DataFrame(
76
- np.random.rand(0, num_class), dtype=self.output_dtype
77
- ).dtypes
78
- return self.new_tileable(
79
- inputs,
80
- shape=shape,
81
- dtypes=dtypes,
82
- columns_value=parse_index(dtypes.index),
83
- index_value=self.data.index_value,
84
- )
85
- else:
86
- # series
87
- return self.new_tileable(
88
- inputs,
89
- shape=shape,
90
- index_value=self.data.index_value,
91
- name="predictions",
92
- dtype=self.output_dtype,
93
- )
66
+ return self.new_tileable(
67
+ inputs,
68
+ shape=shape,
69
+ dtype=self.output_dtype,
70
+ order=TensorOrder.C_ORDER,
71
+ )
94
72
 
95
73
 
96
74
  def predict(
@@ -107,16 +85,21 @@ def predict(
107
85
  strict_shape=False,
108
86
  flag=False,
109
87
  ):
88
+ """
89
+ Using MaxFrame XGBoost model to predict data.
90
+
91
+ Parameters
92
+ ----------
93
+ Parameters are the same as `xgboost.train`. The predict() is lazy-execution mode.
94
+
95
+ Returns
96
+ -------
97
+ results: Booster
98
+ """
110
99
  data = check_data(data)
111
100
  # TODO: check model datatype
112
101
 
113
- num_class = getattr(model.op, "num_class", None)
114
- if isinstance(data, TENSOR_TYPE):
115
- output_types = [OutputType.tensor]
116
- elif num_class is not None:
117
- output_types = [OutputType.dataframe]
118
- else:
119
- output_types = [OutputType.series]
102
+ output_types = [OutputType.tensor]
120
103
 
121
104
  iteration_range = iteration_range or (0, 0)
122
105
 
@@ -41,11 +41,6 @@ else:
41
41
  ):
42
42
  session = kw.pop("session", None)
43
43
  run_kwargs = kw.pop("run_kwargs", dict())
44
- if kw:
45
- raise TypeError(
46
- f"fit got an unexpected keyword argument '{next(iter(kw))}'"
47
- )
48
-
49
44
  dtrain, evals = wrap_evaluation_matrices(
50
45
  None,
51
46
  X,
@@ -57,6 +52,8 @@ else:
57
52
  base_margin_eval_set,
58
53
  )
59
54
  params = self.get_xgb_params()
55
+ if not params.get("objective"):
56
+ params["objective"] = "reg:squarederror"
60
57
  self.evals_result_ = dict()
61
58
  result = train(
62
59
  params,
@@ -71,8 +68,4 @@ else:
71
68
  return self
72
69
 
73
70
  def predict(self, data, **kw):
74
- session = kw.pop("session", None)
75
- run_kwargs = kw.pop("run_kwargs", None)
76
- return predict(
77
- self.get_booster(), data, session=session, run_kwargs=run_kwargs, **kw
78
- )
71
+ return predict(self.get_booster(), data, **kw)