fugue 0.8.2.dev1__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (114)
  1. fugue/__init__.py +9 -5
  2. fugue/_utils/interfaceless.py +1 -558
  3. fugue/_utils/io.py +2 -91
  4. fugue/_utils/registry.py +3 -2
  5. fugue/api.py +1 -0
  6. fugue/bag/bag.py +8 -4
  7. fugue/collections/__init__.py +0 -7
  8. fugue/collections/partition.py +21 -9
  9. fugue/constants.py +3 -1
  10. fugue/dataframe/__init__.py +7 -8
  11. fugue/dataframe/arrow_dataframe.py +1 -2
  12. fugue/dataframe/dataframe.py +17 -18
  13. fugue/dataframe/dataframe_iterable_dataframe.py +22 -6
  14. fugue/dataframe/function_wrapper.py +432 -0
  15. fugue/dataframe/iterable_dataframe.py +3 -0
  16. fugue/dataframe/utils.py +11 -79
  17. fugue/dataset/api.py +0 -4
  18. fugue/dev.py +47 -0
  19. fugue/execution/__init__.py +1 -5
  20. fugue/execution/api.py +36 -14
  21. fugue/execution/execution_engine.py +30 -4
  22. fugue/execution/factory.py +0 -6
  23. fugue/execution/native_execution_engine.py +44 -67
  24. fugue/extensions/_builtins/creators.py +4 -2
  25. fugue/extensions/_builtins/outputters.py +4 -3
  26. fugue/extensions/_builtins/processors.py +3 -3
  27. fugue/extensions/creator/convert.py +5 -2
  28. fugue/extensions/outputter/convert.py +2 -2
  29. fugue/extensions/processor/convert.py +3 -2
  30. fugue/extensions/transformer/convert.py +22 -9
  31. fugue/extensions/transformer/transformer.py +15 -1
  32. fugue/plugins.py +2 -0
  33. fugue/registry.py +0 -39
  34. fugue/sql/_utils.py +1 -1
  35. fugue/workflow/_checkpoint.py +1 -1
  36. fugue/workflow/api.py +13 -13
  37. fugue/workflow/module.py +30 -37
  38. fugue/workflow/workflow.py +6 -0
  39. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/METADATA +37 -23
  40. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/RECORD +112 -101
  41. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/WHEEL +1 -1
  42. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/entry_points.txt +2 -1
  43. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/top_level.txt +1 -0
  44. fugue_contrib/contrib.py +1 -0
  45. fugue_contrib/viz/_ext.py +7 -1
  46. fugue_dask/_io.py +0 -13
  47. fugue_dask/_utils.py +10 -4
  48. fugue_dask/dataframe.py +1 -2
  49. fugue_dask/execution_engine.py +45 -18
  50. fugue_dask/registry.py +8 -33
  51. fugue_duckdb/_io.py +8 -2
  52. fugue_duckdb/_utils.py +7 -2
  53. fugue_duckdb/dask.py +1 -1
  54. fugue_duckdb/dataframe.py +23 -19
  55. fugue_duckdb/execution_engine.py +19 -22
  56. fugue_duckdb/registry.py +11 -34
  57. fugue_ibis/dataframe.py +6 -10
  58. fugue_ibis/execution_engine.py +7 -1
  59. fugue_notebook/env.py +5 -10
  60. fugue_polars/__init__.py +2 -0
  61. fugue_polars/_utils.py +8 -0
  62. fugue_polars/polars_dataframe.py +234 -0
  63. fugue_polars/registry.py +86 -0
  64. fugue_ray/_constants.py +10 -1
  65. fugue_ray/_utils/dataframe.py +36 -9
  66. fugue_ray/_utils/io.py +2 -4
  67. fugue_ray/dataframe.py +16 -12
  68. fugue_ray/execution_engine.py +53 -32
  69. fugue_ray/registry.py +8 -32
  70. fugue_spark/_utils/convert.py +22 -11
  71. fugue_spark/_utils/io.py +0 -13
  72. fugue_spark/_utils/misc.py +27 -0
  73. fugue_spark/_utils/partition.py +11 -18
  74. fugue_spark/dataframe.py +26 -22
  75. fugue_spark/execution_engine.py +136 -54
  76. fugue_spark/registry.py +29 -78
  77. fugue_test/builtin_suite.py +36 -14
  78. fugue_test/dataframe_suite.py +9 -5
  79. fugue_test/execution_suite.py +100 -122
  80. fugue_version/__init__.py +1 -1
  81. tests/fugue/bag/test_array_bag.py +0 -9
  82. tests/fugue/collections/test_partition.py +10 -3
  83. tests/fugue/dataframe/test_function_wrapper.py +293 -0
  84. tests/fugue/dataframe/test_utils.py +2 -34
  85. tests/fugue/execution/test_factory.py +7 -9
  86. tests/fugue/execution/test_naive_execution_engine.py +35 -80
  87. tests/fugue/extensions/test_utils.py +12 -7
  88. tests/fugue/extensions/transformer/test_convert_cotransformer.py +1 -0
  89. tests/fugue/extensions/transformer/test_convert_output_cotransformer.py +1 -0
  90. tests/fugue/extensions/transformer/test_convert_transformer.py +2 -0
  91. tests/fugue/sql/test_workflow.py +1 -1
  92. tests/fugue/sql/test_workflow_parse.py +3 -5
  93. tests/fugue/utils/test_interfaceless.py +1 -325
  94. tests/fugue/utils/test_io.py +0 -80
  95. tests/fugue_dask/test_execution_engine.py +48 -0
  96. tests/fugue_dask/test_io.py +0 -55
  97. tests/fugue_duckdb/test_dataframe.py +2 -2
  98. tests/fugue_duckdb/test_execution_engine.py +16 -1
  99. tests/fugue_duckdb/test_utils.py +1 -1
  100. tests/fugue_ibis/test_dataframe.py +6 -3
  101. tests/fugue_polars/__init__.py +0 -0
  102. tests/fugue_polars/test_api.py +13 -0
  103. tests/fugue_polars/test_dataframe.py +82 -0
  104. tests/fugue_polars/test_transform.py +100 -0
  105. tests/fugue_ray/test_execution_engine.py +40 -4
  106. tests/fugue_spark/test_dataframe.py +0 -8
  107. tests/fugue_spark/test_execution_engine.py +50 -11
  108. tests/fugue_spark/test_importless.py +4 -4
  109. tests/fugue_spark/test_spark_connect.py +82 -0
  110. tests/fugue_spark/utils/test_convert.py +6 -8
  111. tests/fugue_spark/utils/test_io.py +0 -17
  112. fugue/_utils/register.py +0 -3
  113. fugue_test/_utils.py +0 -13
  114. {fugue-0.8.2.dev1.dist-info → fugue-0.8.4.dist-info}/LICENSE +0 -0
fugue_ray/execution_engine.py CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import pyarrow as pa
 import ray
-from duckdb import DuckDBPyConnection, DuckDBPyRelation
+from duckdb import DuckDBPyConnection
 from triad import Schema, assert_or_throw, to_uuid
 from triad.utils.threading import RunOnce
 
@@ -15,14 +15,14 @@ from fugue (
     PartitionCursor,
     PartitionSpec,
 )
-from fugue.constants import KEYWORD_ROWCOUNT
+from fugue.constants import KEYWORD_PARALLELISM, KEYWORD_ROWCOUNT
 from fugue.dataframe.arrow_dataframe import _build_empty_arrow
 from fugue_duckdb.dataframe import DuckDataFrame
 from fugue_duckdb.execution_engine import DuckExecutionEngine
 
-from ._constants import FUGUE_RAY_DEFAULT_BATCH_SIZE
+from ._constants import FUGUE_RAY_DEFAULT_BATCH_SIZE, FUGUE_RAY_ZERO_COPY
 from ._utils.cluster import get_default_partitions, get_default_shuffle_partitions
-from ._utils.dataframe import add_partition_key
+from ._utils.dataframe import add_coarse_partition_key, add_partition_key
 from ._utils.io import RayIO
 from .dataframe import RayDataFrame
 
@@ -45,6 +45,7 @@ class RayMapEngine(MapEngine):
         output_schema: Any,
         partition_spec: PartitionSpec,
         on_init: Optional[Callable[[int, DataFrame], Any]] = None,
+        map_func_format_hint: Optional[str] = None,
     ) -> DataFrame:
         if len(partition_spec.partition_by) == 0:
             return self._map(
@@ -71,12 +72,15 @@ class RayMapEngine(MapEngine):
         partition_spec: PartitionSpec,
         on_init: Optional[Callable[[int, DataFrame], Any]] = None,
     ) -> DataFrame:
-        presort = partition_spec.presort
+        output_schema = Schema(output_schema)
+        input_schema = df.schema
+        presort = partition_spec.get_sorts(
+            input_schema, with_partition_keys=partition_spec.algo == "coarse"
+        )
         presort_tuples = [
             (k, "ascending" if v else "descending") for k, v in presort.items()
         ]
-        output_schema = Schema(output_schema)
-        input_schema = df.schema
+        cursor = partition_spec.get_cursor(input_schema, 0)
         on_init_once: Any = (
             None
             if on_init is None
@@ -89,7 +93,7 @@ class RayMapEngine(MapEngine):
            if adf.shape[0] == 0:
                return _build_empty_arrow(output_schema)
            adf = adf.remove_column(len(input_schema))  # remove partition key
-           if len(presort_tuples) > 0:
+           if len(partition_spec.presort) > 0:
                if pa.__version__ < "7":  # pragma: no cover
                    idx = pa.compute.sort_indices(
                        adf, options=pa.compute.SortOptions(presort_tuples)
@@ -100,8 +104,7 @@ class RayMapEngine(MapEngine):
            input_df = ArrowDataFrame(adf)
            if on_init_once is not None:
                on_init_once(0, input_df)
-           cursor = partition_spec.get_cursor(input_schema, 0)
-           cursor.set(input_df.peek_array(), 0, 0)
+           cursor.set(lambda: input_df.peek_array(), 0, 0)
            output_df = map_func(cursor, input_df)
            return output_df.as_arrow()
 
@@ -117,12 +120,20 @@ class RayMapEngine(MapEngine):
            _df = self.execution_engine.repartition(  # type: ignore
                _df, PartitionSpec(num=n)
            )
-        rdf, _ = add_partition_key(
-            _df.native,
-            keys=partition_spec.partition_by,
-            input_schema=input_schema,
-            output_key=_RAY_PARTITION_KEY,
-        )
+        if partition_spec.algo != "coarse":
+            rdf, _ = add_partition_key(
+                _df.native,
+                keys=partition_spec.partition_by,
+                input_schema=input_schema,
+                output_key=_RAY_PARTITION_KEY,
+            )
+        else:
+            rdf = add_coarse_partition_key(
+                _df.native,
+                keys=partition_spec.partition_by,
+                output_key=_RAY_PARTITION_KEY,
+                bucket=_df.num_partitions,
+            )
 
         gdf = rdf.groupby(_RAY_PARTITION_KEY)
         sdf = gdf.map_groups(
@@ -142,6 +153,7 @@
     ) -> DataFrame:
         output_schema = Schema(output_schema)
         input_schema = df.schema
+        cursor = partition_spec.get_cursor(input_schema, 0)
         on_init_once: Any = (
             None
             if on_init is None
@@ -156,8 +168,7 @@
            input_df = ArrowDataFrame(adf)
            if on_init_once is not None:
                on_init_once(0, input_df)
-           cursor = partition_spec.get_cursor(input_schema, 0)
-           cursor.set(input_df.peek_array(), 0, 0)
+           cursor.set(lambda: input_df.peek_array(), 0, 0)
            output_df = map_func(cursor, input_df)
            return output_df.as_arrow()
 
@@ -175,15 +186,17 @@
            rdf = self.execution_engine.repartition(  # type: ignore
                rdf, PartitionSpec(num=n)
            )
-        batch_size = (
-            self.conf.get_or_throw(FUGUE_RAY_DEFAULT_BATCH_SIZE, object)
-            if FUGUE_RAY_DEFAULT_BATCH_SIZE in self.execution_engine.conf
-            else "default"
-        )
+        mb_args: Dict[str, Any] = {}
+        if FUGUE_RAY_DEFAULT_BATCH_SIZE in self.conf:
+            mb_args["batch_size"] = self.conf.get_or_throw(
+                FUGUE_RAY_DEFAULT_BATCH_SIZE, int
+            )
+        if ray.__version__ >= "2.3":
+            mb_args["zero_copy_batch"] = self.conf.get(FUGUE_RAY_ZERO_COPY, True)
         sdf = rdf.native.map_batches(
             _udf,
             batch_format="pyarrow",
-            batch_size=batch_size,
+            **mb_args,
             **self.execution_engine._get_remote_args(),  # type: ignore
         )
         return RayDataFrame(sdf, schema=output_schema, internal_schema=True)
@@ -202,6 +215,7 @@ class RayExecutionEngine(DuckExecutionEngine):
     ):
         if not ray.is_initialized():  # pragma: no cover
             ray.init()
+
         super().__init__(conf, connection)
         self._io = RayIO(self)
 
@@ -232,12 +246,15 @@
 
         rdf = self._to_ray_df(df)
 
-        num_funcs = {KEYWORD_ROWCOUNT: lambda: _persist_and_count(rdf)}
+        num_funcs = {
+            KEYWORD_ROWCOUNT: lambda: _persist_and_count(rdf),
+            KEYWORD_PARALLELISM: lambda: self.get_current_parallelism(),
+        }
         num = partition_spec.get_num_partitions(**num_funcs)
         pdf = rdf.native
 
         if num > 0:
-            if partition_spec.algo in ["hash", "even"]:
+            if partition_spec.algo in ["hash", "even", "coarse"]:
                 pdf = pdf.repartition(num)
             elif partition_spec.algo == "rand":
                 pdf = pdf.repartition(num, shuffle=True)
@@ -264,6 +281,16 @@
            return df if not as_local else df.as_local()
        return super().convert_yield_dataframe(df, as_local)
 
+    def union(self, df1: DataFrame, df2: DataFrame, distinct: bool = True) -> DataFrame:
+        if distinct:
+            return super().union(df1, df2, distinct)
+        assert_or_throw(
+            df1.schema == df2.schema, ValueError(f"{df1.schema} != {df2.schema}")
+        )
+        tdf1 = self._to_ray_df(df1)
+        tdf2 = self._to_ray_df(df2)
+        return RayDataFrame(tdf1.native.union(tdf2.native), df1.schema)
+
     def load_df(  # type:ignore
         self,
         path: Union[str, List[str]],
@@ -312,12 +339,6 @@
                ValueError("schema must be None when df is a DataFrame"),
            )
            return df
-        if isinstance(df, DuckDBPyRelation):
-            assert_or_throw(
-                schema is None,
-                ValueError("schema must be None when df is a DuckDBPyRelation"),
-            )
-            return DuckDataFrame(df)
        return RayDataFrame(df, schema)
 
    def _get_remote_args(self) -> Dict[str, Any]:
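Note on the "coarse" path introduced above: as the diff reads, rows are hashed into a fixed number of buckets (`bucket=_df.num_partitions`) and the map function receives a whole bucket, presorted by the partition keys, rather than one exact key group. A minimal, hedged sketch of requesting this through the public API; the sample data, the column name `k`, and the `"ray"` engine string are illustrative and not taken from this diff, and the call assumes the standard fugue.api.transform signature with fugue[ray] installed:

import pandas as pd
import fugue.api as fa

def tag_partition(df: pd.DataFrame) -> pd.DataFrame:
    # under "coarse", a single call may see several k groups sharing a bucket
    return df.assign(rows_in_bucket=len(df))

result = fa.transform(
    pd.DataFrame({"k": [1, 1, 2, 3], "v": [10, 20, 30, 40]}),
    tag_partition,
    schema="*,rows_in_bucket:long",
    partition={"by": "k", "algo": "coarse"},
    engine="ray",
)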
fugue_ray/registry.py CHANGED
@@ -1,19 +1,17 @@
-import inspect
-from typing import Any, Optional
+from typing import Any
 
 import ray.data as rd
 from triad import run_at_def
 
-from fugue import DataFrame, is_pandas_or, register_execution_engine
-from fugue._utils.interfaceless import (
+from fugue import DataFrame, register_execution_engine
+from fugue.dev import (
     DataFrameParam,
     ExecutionEngineParam,
-    SimpleAnnotationConverter,
-    register_annotation_converter,
+    fugue_annotated_param,
+    is_pandas_or,
 )
 from fugue.plugins import as_fugue_dataset, infer_execution_engine
 
-
 from .dataframe import RayDataFrame
 from .execution_engine import RayExecutionEngine
 
@@ -36,34 +34,13 @@ def _register_engines() -> None:
     )
 
 
-def _register_annotation_converters() -> None:
-    register_annotation_converter(
-        0.8,
-        SimpleAnnotationConverter(
-            RayExecutionEngine,
-            lambda param: _RayExecutionEngineParam(param),
-        ),
-    )
-    register_annotation_converter(
-        0.8,
-        SimpleAnnotationConverter(rd.Dataset, lambda param: _RayDatasetParam(param)),
-    )
-
-
+@fugue_annotated_param(RayExecutionEngine)
 class _RayExecutionEngineParam(ExecutionEngineParam):
-    def __init__(
-        self,
-        param: Optional[inspect.Parameter],
-    ):
-        super().__init__(
-            param, annotation="RayExecutionEngine", engine_type=RayExecutionEngine
-        )
+    pass
 
 
+@fugue_annotated_param(rd.Dataset)
 class _RayDatasetParam(DataFrameParam):
-    def __init__(self, param: Optional[inspect.Parameter]):
-        super().__init__(param, annotation="ray.data.Dataset")
-
     def to_input_data(self, df: DataFrame, ctx: Any) -> Any:
         assert isinstance(ctx, RayExecutionEngine)
         return ctx._to_ray_df(df).native
@@ -81,4 +58,3 @@ class _RayDatasetParam(DataFrameParam):
 def _register() -> None:
     """Register Ray Execution Engine"""
     _register_engines()
-    _register_annotation_converters()
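The registry now relies on the `fugue_annotated_param` decorator from `fugue.dev` instead of explicit `register_annotation_converter` calls, as the diff above shows. A hedged sketch of the same pattern for another backend; `MyEngine` and `_MyEngineParam` are hypothetical names used only for illustration, and a real registration would decorate a concrete ExecutionEngine subclass:

from fugue import ExecutionEngine
from fugue.dev import ExecutionEngineParam, fugue_annotated_param

class MyEngine(ExecutionEngine):  # hypothetical engine, details omitted
    ...

# Decorating the param class registers MyEngine as a recognized annotation,
# so extension functions annotated with MyEngine can receive the live
# engine instance at runtime.
@fugue_annotated_param(MyEngine)
class _MyEngineParam(ExecutionEngineParam):
    pass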
fugue_spark/_utils/convert.py CHANGED
@@ -1,29 +1,24 @@
 from typing import Any, Iterable, List, Tuple
 
+import cloudpickle
+import pandas as pd
 import pyarrow as pa
 import pyspark.sql as ps
 import pyspark.sql.types as pt
-
-try:  # pyspark < 3
-    from pyspark.sql.types import from_arrow_type, to_arrow_type  # type: ignore
-
-    # https://issues.apache.org/jira/browse/SPARK-29041
-    pt._acceptable_types[pt.BinaryType] = (bytearray, bytes)  # type: ignore # pragma: no cover # noqa: E501 # pylint: disable=line-too-long
-except ImportError:  # pyspark >=3
-    from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
-
 from pyarrow.types import is_list, is_struct, is_timestamp
+from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
 from triad.collections import Schema
 from triad.utils.assertion import assert_arg_not_none, assert_or_throw
 from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP
 from triad.utils.schema import quote_name
+from .misc import is_spark_dataframe
 
 
 def to_spark_schema(obj: Any) -> pt.StructType:
     assert_arg_not_none(obj, "schema")
     if isinstance(obj, pt.StructType):
         return obj
-    if isinstance(obj, ps.DataFrame):
+    if is_spark_dataframe(obj):
         return obj.schema
     return _from_arrow_schema(Schema(obj).pa_schema)
 
@@ -32,7 +27,7 @@ def to_schema(obj: Any) -> Schema:
     assert_arg_not_none(obj, "obj")
     if isinstance(obj, pt.StructType):
         return Schema(_to_arrow_schema(obj))
-    if isinstance(obj, ps.DataFrame):
+    if is_spark_dataframe(obj):
         return to_schema(obj.schema)
     return Schema(obj)
 
@@ -113,6 +108,22 @@ def to_type_safe_input(rows: Iterable[ps.Row], schema: Schema) -> Iterable[List[
            yield r
 
 
+def to_pandas(df: ps.DataFrame) -> pd.DataFrame:
+    if pd.__version__ < "2" or not any(
+        isinstance(x.dataType, (pt.TimestampType, pt.TimestampNTZType))
+        for x in df.schema.fields
+    ):
+        return df.toPandas()
+
+    def serialize(dfs):  # pragma: no cover
+        for df in dfs:
+            data = cloudpickle.dumps(df)
+            yield pd.DataFrame([[data]], columns=["data"])
+
+    sdf = df.mapInPandas(serialize, schema="data binary")
+    return pd.concat(cloudpickle.loads(x.data) for x in sdf.collect())
+
+
 # TODO: the following function always set nullable to true,
 # but should we use field.nullable?
 def _to_arrow_type(dt: pt.DataType) -> pa.DataType:
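The new `to_pandas` helper only takes the cloudpickle detour when pandas>=2 is installed and the schema contains timestamp columns (apparently to sidestep `toPandas()` conversion problems in that combination); otherwise it is a plain `toPandas()`. A short, assumed usage sketch; the session setup and sample query are illustrative:

from pyspark.sql import SparkSession

from fugue_spark._utils.convert import to_pandas

spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.sql("SELECT TIMESTAMP '2023-01-01 00:00:00' AS ts, 1 AS v")

# With pandas>=2 and a timestamp column, each partition is pickled on the
# executors and reassembled on the driver; otherwise this is just toPandas().
pdf = to_pandas(sdf)
print(pdf.dtypes)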
fugue_spark/_utils/io.py CHANGED
@@ -21,7 +21,6 @@ class SparkIO(object):
            "csv": self._load_csv,
            "parquet": self._load_parquet,
            "json": self._load_json,
-           "avro": self._load_avro,
        }
 
    def load_df(
@@ -136,15 +135,3 @@
            return SparkDataFrame(reader.load(p))[columns]
        schema = Schema(columns)
        return SparkDataFrame(reader.load(p)[schema.names], schema)
-
-    def _load_avro(self, p: List[str], columns: Any = None, **kwargs: Any) -> DataFrame:
-        reader = self._session.read.format(
-            "avro"
-        )  # avro is an external data source that has built-in support since spark 2.4
-        reader.options(**kwargs)
-        if columns is None:
-            return SparkDataFrame(reader.load(p))
-        if isinstance(columns, list):  # column names
-            return SparkDataFrame(reader.load(p))[columns]
-        schema = Schema(columns)
-        return SparkDataFrame(reader.load(p)[schema.names], schema)
fugue_spark/_utils/misc.py ADDED
@@ -0,0 +1,27 @@
+from typing import Any
+
+try:
+    from pyspark.sql.connect.session import SparkSession as SparkConnectSession
+    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
+except ImportError:  # pragma: no cover
+    SparkConnectSession = None
+    SparkConnectDataFrame = None
+import pyspark.sql as ps
+
+
+def is_spark_connect(session: Any) -> bool:
+    return SparkConnectSession is not None and isinstance(
+        session, (SparkConnectSession, SparkConnectDataFrame)
+    )
+
+
+def is_spark_dataframe(df: Any) -> bool:
+    return isinstance(df, ps.DataFrame) or (
+        SparkConnectDataFrame is not None and isinstance(df, SparkConnectDataFrame)
+    )
+
+
+def is_spark_session(session: Any) -> bool:
+    return isinstance(session, ps.SparkSession) or (
+        SparkConnectSession is not None and isinstance(session, SparkConnectSession)
+    )
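These helpers let the rest of the Spark integration treat classic and Spark Connect objects uniformly. A small usage sketch; the local session and sample frame are assumptions for illustration, not part of the diff:

from pyspark.sql import SparkSession

from fugue_spark._utils.misc import is_spark_dataframe, is_spark_session

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([(1, "a")], ["id", "name"])

assert is_spark_session(spark)       # True for classic and Connect sessions
assert is_spark_dataframe(df)        # True for classic and Connect DataFrames
assert not is_spark_dataframe("x")   # non-Spark objects are rejected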
fugue_spark/_utils/partition.py CHANGED
@@ -1,11 +1,12 @@
-import random
 from typing import Any, Iterable, List
 
 import pyspark.sql as ps
-from fugue_spark._utils.convert import to_schema, to_spark_schema
+import pyspark.sql.functions as psf
 from pyspark import RDD
 from pyspark.sql import SparkSession
-from pyspark.sql.functions import lit
+import warnings
+from .convert import to_schema, to_spark_schema
+from .misc import is_spark_connect
 
 _PARTITION_DUMMY_KEY = "__partition_dummy_key__"
 
@@ -28,16 +29,10 @@ def rand_repartition(
    if len(cols) > 0 or num <= 1:
        return hash_repartition(session, df, num, cols)
 
-    def _rand(rows: Iterable[Any], n: int) -> Iterable[Any]:  # pragma: no cover
-        for row in rows:
-            yield (random.randint(0, n - 1), row)
-
-    rdd = (
-        df.rdd.mapPartitions(lambda r: _rand(r, num))
-        .partitionBy(num, lambda k: k)
-        .mapPartitions(_to_rows)
+    tdf = df.withColumn(
+        _PARTITION_DUMMY_KEY, (psf.rand(0) * psf.lit(2**15 - 1)).cast("long")
    )
-    return session.createDataFrame(rdd, df.schema)
+    return tdf.repartition(num, _PARTITION_DUMMY_KEY)[df.schema.names]
 
 
 def even_repartition(
@@ -45,6 +40,9 @@
 ) -> ps.DataFrame:
    if num == 1:
        return _single_repartition(df)
+    if is_spark_connect(session):  # pragma: no cover
+        warnings.warn("Even repartitioning is not supported by Spark Connect")
+        return hash_repartition(session, df, num, cols)
    if len(cols) == 0:
        if num == 0:
            return df
@@ -82,7 +80,7 @@
 
 def _single_repartition(df: ps.DataFrame) -> ps.DataFrame:
    return (
-        df.withColumn(_PARTITION_DUMMY_KEY, lit(0))
+        df.withColumn(_PARTITION_DUMMY_KEY, psf.lit(0))
        .repartition(_PARTITION_DUMMY_KEY)
        .drop(_PARTITION_DUMMY_KEY)
    )
@@ -93,11 +91,6 @@ def _to_rows(rdd: Iterable[Any]) -> Iterable[Any]:  # pragma: no cover
        yield item[1]
 
 
-def _to_rows_with_key(rdd: Iterable[Any]) -> Iterable[Any]:  # pragma: no cover
-    for item in rdd:
-        yield list(item[1]) + [item[0]]
-
-
 def _zipWithIndex(rdd: RDD, to_rows: bool = False) -> RDD:
    """
    Modified from
fugue_spark/dataframe.py CHANGED
@@ -13,7 +13,6 @@ from fugue.dataframe import (
     DataFrame,
     IterableDataFrame,
     LocalBoundedDataFrame,
-    LocalDataFrame,
     PandasDataFrame,
 )
 from fugue.exceptions import FugueDataFrameOperationError
@@ -31,7 +30,9 @@ from fugue.plugins import (
     rename,
     select_columns,
 )
-from fugue_spark._utils.convert import to_cast_expression, to_schema, to_type_safe_input
+
+from ._utils.convert import to_cast_expression, to_pandas, to_schema, to_type_safe_input
+from ._utils.misc import is_spark_connect, is_spark_dataframe
 
 
 class SparkDataFrame(DataFrame):
@@ -52,12 +53,12 @@ class SparkDataFrame(DataFrame):
 
    def __init__(self, df: Any = None, schema: Any = None):  # noqa: C901
        self._lock = SerializableRLock()
-        if isinstance(df, ps.DataFrame):
+        if is_spark_dataframe(df):
            if schema is not None:
                schema = to_schema(schema).assert_not_empty()
                has_cast, expr = to_cast_expression(df, schema, True)
                if has_cast:
-                    df = df.selectExpr(*expr)
+                    df = df.selectExpr(*expr)  # type: ignore
            else:
                schema = to_schema(df).assert_not_empty()
            self._native = df
@@ -90,12 +91,12 @@
    def is_bounded(self) -> bool:
        return True
 
-    def as_local(self) -> LocalDataFrame:
+    def as_local_bounded(self) -> LocalBoundedDataFrame:
        if any(pa.types.is_nested(t) for t in self.schema.types):
            data = list(to_type_safe_input(self.native.collect(), self.schema))
-            res: LocalDataFrame = ArrayDataFrame(data, self.schema)
+            res: LocalBoundedDataFrame = ArrayDataFrame(data, self.schema)
        else:
-            res = PandasDataFrame(self.native.toPandas(), self.schema)
+            res = PandasDataFrame(self.as_pandas(), self.schema)
        if self.has_metadata:
            res.reset_metadata(self.metadata)
        return res
@@ -127,7 +128,7 @@
        return SparkDataFrame(self.native[schema.names])
 
    def as_pandas(self) -> pd.DataFrame:
-        return self.native.toPandas()
+        return to_pandas(self.native)
 
    def rename(self, columns: Dict[str, str]) -> DataFrame:
        try:
@@ -151,6 +152,9 @@
    def as_array_iterable(
        self, columns: Optional[List[str]] = None, type_safe: bool = False
    ) -> Iterable[Any]:
+        if is_spark_connect(self.native):  # pragma: no cover
+            yield from self.as_array(columns, type_safe=type_safe)
+            return
        sdf = self._select_columns(columns)
        if not type_safe:
            for row in to_type_safe_input(sdf.native.rdd.toLocalIterator(), sdf.schema):
@@ -183,47 +187,47 @@
        return SparkDataFrame(self.native.select(*columns))
 
 
-@is_df.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_df.candidate(lambda df: is_spark_dataframe(df))
 def _spark_is_df(df: ps.DataFrame) -> bool:
    return True
 
 
-@get_num_partitions.candidate(lambda df: isinstance(df, ps.DataFrame))
+@get_num_partitions.candidate(lambda df: is_spark_dataframe(df))
 def _spark_num_partitions(df: ps.DataFrame) -> int:
    return df.rdd.getNumPartitions()
 
 
-@count.candidate(lambda df: isinstance(df, ps.DataFrame))
+@count.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_count(df: ps.DataFrame) -> int:
    return df.count()
 
 
-@is_bounded.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_bounded.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_bounded(df: ps.DataFrame) -> bool:
    return True
 
 
-@is_empty.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_empty.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_empty(df: ps.DataFrame) -> bool:
    return df.first() is None
 
 
-@is_local.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_local.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_local(df: ps.DataFrame) -> bool:
    return False
 
 
-@as_local_bounded.candidate(lambda df: isinstance(df, ps.DataFrame))
+@as_local_bounded.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_as_local(df: ps.DataFrame) -> pd.DataFrame:
-    return df.toPandas()
+    return to_pandas(df)
 
 
-@get_column_names.candidate(lambda df: isinstance(df, ps.DataFrame))
+@get_column_names.candidate(lambda df: is_spark_dataframe(df))
 def _get_spark_df_columns(df: ps.DataFrame) -> List[Any]:
    return df.columns
 
 
-@rename.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@rename.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _rename_spark_df(
    df: ps.DataFrame, columns: Dict[str, Any], as_fugue: bool = False
 ) -> ps.DataFrame:
@@ -233,7 +237,7 @@
    return _adjust_df(_rename_spark_dataframe(df, columns), as_fugue=as_fugue)
 
 
-@drop_columns.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@drop_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _drop_spark_df_columns(
    df: ps.DataFrame, columns: List[str], as_fugue: bool = False
 ) -> Any:
@@ -245,7 +249,7 @@
    return _adjust_df(df[cols], as_fugue=as_fugue)
 
 
-@select_columns.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@select_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _select_spark_df_columns(
    df: ps.DataFrame, columns: List[Any], as_fugue: bool = False
 ) -> Any:
@@ -255,7 +259,7 @@
    return _adjust_df(df[columns], as_fugue=as_fugue)
 
 
-@head.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@head.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _spark_df_head(
    df: ps.DataFrame,
    n: int,
@@ -265,7 +269,7 @@
    if columns is not None:
        df = df[columns]
    res = df.limit(n)
-    return SparkDataFrame(res).as_local() if as_fugue else res.toPandas()
+    return SparkDataFrame(res).as_local() if as_fugue else to_pandas(res)
 
 
 def _rename_spark_dataframe(df: ps.DataFrame, names: Dict[str, Any]) -> ps.DataFrame:
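With the candidates above switched to `is_spark_dataframe`, the fugue plugin functions dispatch on both classic and Spark Connect DataFrames. A brief, assumed usage sketch of the dispatched helpers; the local session and data are illustrative:

import fugue.api as fa
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

fa.is_df(sdf)             # True, via the is_df candidate registered above
fa.get_column_names(sdf)  # ['id', 'name']
fa.as_pandas(sdf)         # routed through to_pandas for pandas>=2 safety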