fugue 0.8.2.dev4__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. fugue/__init__.py +0 -1
  2. fugue/_utils/io.py +2 -91
  3. fugue/api.py +1 -0
  4. fugue/collections/partition.py +12 -6
  5. fugue/constants.py +1 -1
  6. fugue/dataframe/__init__.py +1 -7
  7. fugue/dataframe/arrow_dataframe.py +1 -1
  8. fugue/dataframe/function_wrapper.py +2 -3
  9. fugue/dataframe/utils.py +10 -84
  10. fugue/execution/api.py +34 -12
  11. fugue/execution/native_execution_engine.py +33 -19
  12. fugue/extensions/_builtins/creators.py +4 -2
  13. fugue/extensions/_builtins/outputters.py +3 -3
  14. fugue/extensions/_builtins/processors.py +2 -3
  15. fugue/plugins.py +1 -0
  16. fugue/workflow/_checkpoint.py +1 -1
  17. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/METADATA +20 -10
  18. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/RECORD +67 -65
  19. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/entry_points.txt +2 -2
  20. fugue_contrib/viz/_ext.py +7 -1
  21. fugue_dask/_io.py +0 -13
  22. fugue_dask/_utils.py +10 -4
  23. fugue_dask/execution_engine.py +42 -16
  24. fugue_duckdb/_utils.py +7 -2
  25. fugue_duckdb/dask.py +1 -1
  26. fugue_duckdb/dataframe.py +17 -10
  27. fugue_duckdb/execution_engine.py +12 -22
  28. fugue_ibis/dataframe.py +2 -7
  29. fugue_notebook/env.py +5 -10
  30. fugue_polars/_utils.py +0 -40
  31. fugue_polars/polars_dataframe.py +22 -7
  32. fugue_ray/_constants.py +8 -1
  33. fugue_ray/_utils/dataframe.py +31 -4
  34. fugue_ray/_utils/io.py +2 -4
  35. fugue_ray/dataframe.py +13 -4
  36. fugue_ray/execution_engine.py +39 -21
  37. fugue_spark/_utils/convert.py +22 -11
  38. fugue_spark/_utils/io.py +0 -13
  39. fugue_spark/_utils/misc.py +27 -0
  40. fugue_spark/_utils/partition.py +11 -18
  41. fugue_spark/dataframe.py +24 -19
  42. fugue_spark/execution_engine.py +61 -35
  43. fugue_spark/registry.py +15 -3
  44. fugue_test/builtin_suite.py +7 -9
  45. fugue_test/dataframe_suite.py +7 -3
  46. fugue_test/execution_suite.py +100 -122
  47. fugue_version/__init__.py +1 -1
  48. tests/fugue/collections/test_partition.py +6 -3
  49. tests/fugue/dataframe/test_utils.py +2 -43
  50. tests/fugue/execution/test_naive_execution_engine.py +33 -0
  51. tests/fugue/utils/test_io.py +0 -80
  52. tests/fugue_dask/test_execution_engine.py +45 -0
  53. tests/fugue_dask/test_io.py +0 -55
  54. tests/fugue_duckdb/test_dataframe.py +2 -2
  55. tests/fugue_duckdb/test_utils.py +1 -1
  56. tests/fugue_polars/test_api.py +13 -0
  57. tests/fugue_polars/test_transform.py +11 -5
  58. tests/fugue_ray/test_execution_engine.py +32 -1
  59. tests/fugue_spark/test_dataframe.py +0 -8
  60. tests/fugue_spark/test_execution_engine.py +48 -10
  61. tests/fugue_spark/test_importless.py +4 -4
  62. tests/fugue_spark/test_spark_connect.py +82 -0
  63. tests/fugue_spark/utils/test_convert.py +6 -8
  64. tests/fugue_spark/utils/test_io.py +0 -17
  65. fugue_test/_utils.py +0 -13
  66. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/LICENSE +0 -0
  67. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/WHEEL +0 -0
  68. {fugue-0.8.2.dev4.dist-info → fugue-0.8.4.dist-info}/top_level.txt +0 -0
fugue_ray/dataframe.py CHANGED
@@ -17,6 +17,7 @@ from fugue.plugins import (
     rename,
 )
 
+from ._constants import _ZERO_COPY
 from ._utils.dataframe import build_empty, get_dataset_format
 
 
@@ -140,7 +141,10 @@ class RayDataFrame(DataFrame):
         if cols == self.columns:
             return self
         rdf = self.native.map_batches(
-            lambda b: b.select(cols), batch_format="pyarrow", **self._remote_args()
+            lambda b: b.select(cols),
+            batch_format="pyarrow",
+            **_ZERO_COPY,
+            **self._remote_args(),
         )
         return RayDataFrame(rdf, self.schema.extract(cols), internal_schema=True)
 
@@ -174,6 +178,7 @@ class RayDataFrame(DataFrame):
         rdf = self.native.map_batches(
             lambda b: b.rename_columns(new_cols),
             batch_format="pyarrow",
+            **_ZERO_COPY,
             **self._remote_args(),
         )
         return RayDataFrame(rdf, schema=new_schema, internal_schema=True)
@@ -188,7 +193,7 @@ class RayDataFrame(DataFrame):
         if self.schema == new_schema:
             return self
         rdf = self.native.map_batches(
-            _alter, batch_format="pyarrow", **self._remote_args()
+            _alter, batch_format="pyarrow", **_ZERO_COPY, **self._remote_args()
         )
         return RayDataFrame(rdf, schema=new_schema, internal_schema=True)
 
@@ -231,7 +236,9 @@ class RayDataFrame(DataFrame):
             return ArrowDataFrame(table).alter_columns(schema).native  # type: ignore
 
         return (
-            rdf.map_batches(_alter, batch_format="pyarrow", **self._remote_args()),
+            rdf.map_batches(
+                _alter, batch_format="pyarrow", **_ZERO_COPY, **self._remote_args()
+            ),
             schema,
         )
 
@@ -273,7 +280,9 @@ def _rename_ray_dataframe(df: rd.Dataset, columns: Dict[str, Any]) -> rd.Dataset
     if len(missing) > 0:
         raise FugueDataFrameOperationError("found nonexistent columns: {missing}")
     new_cols = [columns.get(name, name) for name in cols]
-    return df.map_batches(lambda b: b.rename_columns(new_cols), batch_format="pyarrow")
+    return df.map_batches(
+        lambda b: b.rename_columns(new_cols), batch_format="pyarrow", **_ZERO_COPY
+    )
 
 
 def _get_arrow_tables(df: rd.Dataset) -> Iterable[pa.Table]:
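Note: the `_ZERO_COPY` keyword arguments threaded into every `map_batches` call above come from `fugue_ray/_constants.py`, whose diff is not shown in this excerpt. A hedged sketch of what such a constant could look like, assuming it merely enables Ray's zero-copy batch handling on versions that support it (the version cutoff and the dict contents are assumptions, not the package's actual definition):

```python
# Hypothetical sketch only; the real constant lives in fugue_ray/_constants.py.
from typing import Any, Dict

import ray

# Ray's Dataset.map_batches accepts zero_copy_batch=True on newer releases;
# older releases do not, so fall back to an empty dict there (assumed cutoff).
_ZERO_COPY: Dict[str, Any] = (
    {"zero_copy_batch": True} if ray.__version__ >= "2.2" else {}
)

# Spreading the dict keeps call sites version-agnostic:
#     ds.map_batches(fn, batch_format="pyarrow", **_ZERO_COPY)
```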
fugue_ray/execution_engine.py CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import pyarrow as pa
 import ray
-from duckdb import DuckDBPyConnection, DuckDBPyRelation
+from duckdb import DuckDBPyConnection
 from triad import Schema, assert_or_throw, to_uuid
 from triad.utils.threading import RunOnce
 
@@ -15,14 +15,14 @@ from fugue import (
     PartitionCursor,
     PartitionSpec,
 )
-from fugue.constants import KEYWORD_ROWCOUNT
+from fugue.constants import KEYWORD_PARALLELISM, KEYWORD_ROWCOUNT
 from fugue.dataframe.arrow_dataframe import _build_empty_arrow
 from fugue_duckdb.dataframe import DuckDataFrame
 from fugue_duckdb.execution_engine import DuckExecutionEngine
 
 from ._constants import FUGUE_RAY_DEFAULT_BATCH_SIZE, FUGUE_RAY_ZERO_COPY
 from ._utils.cluster import get_default_partitions, get_default_shuffle_partitions
-from ._utils.dataframe import add_partition_key
+from ._utils.dataframe import add_coarse_partition_key, add_partition_key
 from ._utils.io import RayIO
 from .dataframe import RayDataFrame
 
@@ -72,12 +72,14 @@ class RayMapEngine(MapEngine):
         partition_spec: PartitionSpec,
         on_init: Optional[Callable[[int, DataFrame], Any]] = None,
     ) -> DataFrame:
-        presort = partition_spec.presort
+        output_schema = Schema(output_schema)
+        input_schema = df.schema
+        presort = partition_spec.get_sorts(
+            input_schema, with_partition_keys=partition_spec.algo == "coarse"
+        )
         presort_tuples = [
            (k, "ascending" if v else "descending") for k, v in presort.items()
         ]
-        output_schema = Schema(output_schema)
-        input_schema = df.schema
         cursor = partition_spec.get_cursor(input_schema, 0)
         on_init_once: Any = (
             None
@@ -91,7 +93,7 @@ class RayMapEngine(MapEngine):
             if adf.shape[0] == 0:
                 return _build_empty_arrow(output_schema)
             adf = adf.remove_column(len(input_schema))  # remove partition key
-            if len(presort_tuples) > 0:
+            if len(partition_spec.presort) > 0:
                 if pa.__version__ < "7":  # pragma: no cover
                     idx = pa.compute.sort_indices(
                         adf, options=pa.compute.SortOptions(presort_tuples)
@@ -118,12 +120,20 @@ class RayMapEngine(MapEngine):
             _df = self.execution_engine.repartition(  # type: ignore
                 _df, PartitionSpec(num=n)
             )
-        rdf, _ = add_partition_key(
-            _df.native,
-            keys=partition_spec.partition_by,
-            input_schema=input_schema,
-            output_key=_RAY_PARTITION_KEY,
-        )
+        if partition_spec.algo != "coarse":
+            rdf, _ = add_partition_key(
+                _df.native,
+                keys=partition_spec.partition_by,
+                input_schema=input_schema,
+                output_key=_RAY_PARTITION_KEY,
+            )
+        else:
+            rdf = add_coarse_partition_key(
+                _df.native,
+                keys=partition_spec.partition_by,
+                output_key=_RAY_PARTITION_KEY,
+                bucket=_df.num_partitions,
+            )
 
         gdf = rdf.groupby(_RAY_PARTITION_KEY)
         sdf = gdf.map_groups(
@@ -205,6 +215,7 @@ class RayExecutionEngine(DuckExecutionEngine):
     ):
         if not ray.is_initialized():  # pragma: no cover
             ray.init()
+
         super().__init__(conf, connection)
         self._io = RayIO(self)
 
@@ -235,12 +246,15 @@ class RayExecutionEngine(DuckExecutionEngine):
 
         rdf = self._to_ray_df(df)
 
-        num_funcs = {KEYWORD_ROWCOUNT: lambda: _persist_and_count(rdf)}
+        num_funcs = {
+            KEYWORD_ROWCOUNT: lambda: _persist_and_count(rdf),
+            KEYWORD_PARALLELISM: lambda: self.get_current_parallelism(),
+        }
         num = partition_spec.get_num_partitions(**num_funcs)
         pdf = rdf.native
 
         if num > 0:
-            if partition_spec.algo in ["hash", "even"]:
+            if partition_spec.algo in ["hash", "even", "coarse"]:
                 pdf = pdf.repartition(num)
             elif partition_spec.algo == "rand":
                 pdf = pdf.repartition(num, shuffle=True)
@@ -267,6 +281,16 @@ class RayExecutionEngine(DuckExecutionEngine):
             return df if not as_local else df.as_local()
         return super().convert_yield_dataframe(df, as_local)
 
+    def union(self, df1: DataFrame, df2: DataFrame, distinct: bool = True) -> DataFrame:
+        if distinct:
+            return super().union(df1, df2, distinct)
+        assert_or_throw(
+            df1.schema == df2.schema, ValueError(f"{df1.schema} != {df2.schema}")
+        )
+        tdf1 = self._to_ray_df(df1)
+        tdf2 = self._to_ray_df(df2)
+        return RayDataFrame(tdf1.native.union(tdf2.native), df1.schema)
+
     def load_df(  # type:ignore
         self,
         path: Union[str, List[str]],
@@ -315,12 +339,6 @@ class RayExecutionEngine(DuckExecutionEngine):
                 ValueError("schema must be None when df is a DataFrame"),
             )
             return df
-        if isinstance(df, DuckDBPyRelation):
-            assert_or_throw(
-                schema is None,
-                ValueError("schema must be None when df is a DuckDBPyRelation"),
-            )
-            return DuckDataFrame(df)
         return RayDataFrame(df, schema)
 
     def _get_remote_args(self) -> Dict[str, Any]:
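Note: the new `union` override above keeps distinct unions on the DuckDB-backed path but handles UNION ALL with Ray's native `Dataset.union`. A hedged usage sketch, assuming a local Ray runtime and default engine construction (the data and schema are illustrative):

```python
# Illustrative only: exercises the UNION ALL fast path added in this release.
import ray
import ray.data as rd

from fugue_ray import RayDataFrame, RayExecutionEngine

ray.init(ignore_reinit_error=True)
engine = RayExecutionEngine()

df1 = RayDataFrame(rd.from_items([{"a": 1}, {"a": 2}]), "a:long")
df2 = RayDataFrame(rd.from_items([{"a": 2}, {"a": 3}]), "a:long")

# distinct=False uses Ray's Dataset.union directly; distinct=True still
# delegates to the DuckDB engine's implementation
res = engine.union(df1, df2, distinct=False)
print(res.as_pandas())  # four rows, duplicates preserved
```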
fugue_spark/_utils/convert.py CHANGED
@@ -1,29 +1,24 @@
 from typing import Any, Iterable, List, Tuple
 
+import cloudpickle
+import pandas as pd
 import pyarrow as pa
 import pyspark.sql as ps
 import pyspark.sql.types as pt
-
-try:  # pyspark < 3
-    from pyspark.sql.types import from_arrow_type, to_arrow_type  # type: ignore
-
-    # https://issues.apache.org/jira/browse/SPARK-29041
-    pt._acceptable_types[pt.BinaryType] = (bytearray, bytes)  # type: ignore # pragma: no cover # noqa: E501 # pylint: disable=line-too-long
-except ImportError:  # pyspark >=3
-    from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
-
 from pyarrow.types import is_list, is_struct, is_timestamp
+from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
 from triad.collections import Schema
 from triad.utils.assertion import assert_arg_not_none, assert_or_throw
 from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP
 from triad.utils.schema import quote_name
+from .misc import is_spark_dataframe
 
 
 def to_spark_schema(obj: Any) -> pt.StructType:
     assert_arg_not_none(obj, "schema")
     if isinstance(obj, pt.StructType):
         return obj
-    if isinstance(obj, ps.DataFrame):
+    if is_spark_dataframe(obj):
         return obj.schema
     return _from_arrow_schema(Schema(obj).pa_schema)
 
@@ -32,7 +27,7 @@ def to_schema(obj: Any) -> Schema:
     assert_arg_not_none(obj, "obj")
     if isinstance(obj, pt.StructType):
         return Schema(_to_arrow_schema(obj))
-    if isinstance(obj, ps.DataFrame):
+    if is_spark_dataframe(obj):
         return to_schema(obj.schema)
     return Schema(obj)
 
@@ -113,6 +108,22 @@ def to_type_safe_input(rows: Iterable[ps.Row], schema: Schema) -> Iterable[List[
             yield r
 
 
+def to_pandas(df: ps.DataFrame) -> pd.DataFrame:
+    if pd.__version__ < "2" or not any(
+        isinstance(x.dataType, (pt.TimestampType, pt.TimestampNTZType))
+        for x in df.schema.fields
+    ):
+        return df.toPandas()
+
+    def serialize(dfs):  # pragma: no cover
+        for df in dfs:
+            data = cloudpickle.dumps(df)
+            yield pd.DataFrame([[data]], columns=["data"])
+
+    sdf = df.mapInPandas(serialize, schema="data binary")
+    return pd.concat(cloudpickle.loads(x.data) for x in sdf.collect())
+
+
 # TODO: the following function always set nullable to true,
 # but should we use field.nullable?
 def _to_arrow_type(dt: pt.DataType) -> pa.DataType:
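Note: the new `to_pandas` helper above avoids `DataFrame.toPandas()` when pandas 2.x is installed and timestamp columns are present, instead serializing each partition with cloudpickle inside `mapInPandas` and concatenating the pieces on the driver. A usage sketch (the session setup and column names are illustrative):

```python
# Illustrative call; assumes a local SparkSession is available.
from pyspark.sql import SparkSession

from fugue_spark._utils.convert import to_pandas

spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.sql("SELECT 1 AS a, current_timestamp() AS ts")

# With pandas>=2 and a timestamp column present, this takes the
# cloudpickle round-trip path instead of sdf.toPandas()
pdf = to_pandas(sdf)
print(pdf.dtypes)
```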
fugue_spark/_utils/io.py CHANGED
@@ -21,7 +21,6 @@ class SparkIO(object):
             "csv": self._load_csv,
             "parquet": self._load_parquet,
             "json": self._load_json,
-            "avro": self._load_avro,
         }
 
     def load_df(
@@ -136,15 +135,3 @@ class SparkIO(object):
             return SparkDataFrame(reader.load(p))[columns]
         schema = Schema(columns)
         return SparkDataFrame(reader.load(p)[schema.names], schema)
-
-    def _load_avro(self, p: List[str], columns: Any = None, **kwargs: Any) -> DataFrame:
-        reader = self._session.read.format(
-            "avro"
-        )  # avro is an external data source that has built-in support since spark 2.4
-        reader.options(**kwargs)
-        if columns is None:
-            return SparkDataFrame(reader.load(p))
-        if isinstance(columns, list):  # column names
-            return SparkDataFrame(reader.load(p))[columns]
-        schema = Schema(columns)
-        return SparkDataFrame(reader.load(p)[schema.names], schema)
fugue_spark/_utils/misc.py ADDED
@@ -0,0 +1,27 @@
+from typing import Any
+
+try:
+    from pyspark.sql.connect.session import SparkSession as SparkConnectSession
+    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
+except ImportError:  # pragma: no cover
+    SparkConnectSession = None
+    SparkConnectDataFrame = None
+import pyspark.sql as ps
+
+
+def is_spark_connect(session: Any) -> bool:
+    return SparkConnectSession is not None and isinstance(
+        session, (SparkConnectSession, SparkConnectDataFrame)
+    )
+
+
+def is_spark_dataframe(df: Any) -> bool:
+    return isinstance(df, ps.DataFrame) or (
+        SparkConnectDataFrame is not None and isinstance(df, SparkConnectDataFrame)
+    )
+
+
+def is_spark_session(session: Any) -> bool:
+    return isinstance(session, ps.SparkSession) or (
+        SparkConnectSession is not None and isinstance(session, SparkConnectSession)
+    )
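Note: the helpers in this new module let the rest of the Spark integration treat classic and Spark Connect sessions and DataFrames uniformly. A quick usage sketch, assuming a classic (non-Connect) local session:

```python
# Illustrative check with a classic local session.
from pyspark.sql import SparkSession

from fugue_spark._utils.misc import (
    is_spark_connect,
    is_spark_dataframe,
    is_spark_session,
)

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.range(3)

assert is_spark_session(spark)
assert is_spark_dataframe(df)
assert not is_spark_connect(spark)  # would be True for a Connect session
```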
fugue_spark/_utils/partition.py CHANGED
@@ -1,11 +1,12 @@
-import random
 from typing import Any, Iterable, List
 
 import pyspark.sql as ps
-from fugue_spark._utils.convert import to_schema, to_spark_schema
+import pyspark.sql.functions as psf
 from pyspark import RDD
 from pyspark.sql import SparkSession
-from pyspark.sql.functions import lit
+import warnings
+from .convert import to_schema, to_spark_schema
+from .misc import is_spark_connect
 
 _PARTITION_DUMMY_KEY = "__partition_dummy_key__"
 
@@ -28,16 +29,10 @@ def rand_repartition(
     if len(cols) > 0 or num <= 1:
         return hash_repartition(session, df, num, cols)
 
-    def _rand(rows: Iterable[Any], n: int) -> Iterable[Any]:  # pragma: no cover
-        for row in rows:
-            yield (random.randint(0, n - 1), row)
-
-    rdd = (
-        df.rdd.mapPartitions(lambda r: _rand(r, num))
-        .partitionBy(num, lambda k: k)
-        .mapPartitions(_to_rows)
+    tdf = df.withColumn(
+        _PARTITION_DUMMY_KEY, (psf.rand(0) * psf.lit(2**15 - 1)).cast("long")
     )
-    return session.createDataFrame(rdd, df.schema)
+    return tdf.repartition(num, _PARTITION_DUMMY_KEY)[df.schema.names]
 
 
 def even_repartition(
@@ -45,6 +40,9 @@ def even_repartition(
 ) -> ps.DataFrame:
     if num == 1:
         return _single_repartition(df)
+    if is_spark_connect(session):  # pragma: no cover
+        warnings.warn("Even repartitioning is not supported by Spark Connect")
+        return hash_repartition(session, df, num, cols)
     if len(cols) == 0:
         if num == 0:
             return df
@@ -82,7 +80,7 @@ def even_repartition(
 
 def _single_repartition(df: ps.DataFrame) -> ps.DataFrame:
     return (
-        df.withColumn(_PARTITION_DUMMY_KEY, lit(0))
+        df.withColumn(_PARTITION_DUMMY_KEY, psf.lit(0))
         .repartition(_PARTITION_DUMMY_KEY)
         .drop(_PARTITION_DUMMY_KEY)
     )
@@ -93,11 +91,6 @@ def _to_rows(rdd: Iterable[Any]) -> Iterable[Any]:  # pragma: no cover
         yield item[1]
 
 
-def _to_rows_with_key(rdd: Iterable[Any]) -> Iterable[Any]:  # pragma: no cover
-    for item in rdd:
-        yield list(item[1]) + [item[0]]
-
-
 def _zipWithIndex(rdd: RDD, to_rows: bool = False) -> RDD:
     """
    Modified from
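Note: the rewritten `rand_repartition` above drops the RDD round-trip in favor of a DataFrame-only shuffle: it derives a pseudo-random long column and repartitions on it, which also keeps the code usable where RDD APIs are unavailable. A standalone sketch of the same idea (the column name and partition count are illustrative):

```python
# Standalone illustration of repartitioning on a derived random key.
import pyspark.sql.functions as psf
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.range(1000)

# A random key in [0, 2**15 - 1], cast to long so it can be shuffled on
key = (psf.rand(0) * psf.lit(2**15 - 1)).cast("long")
result = (
    df.withColumn("__rand_key", key)
    .repartition(8, "__rand_key")
    .drop("__rand_key")
)
print(result.rdd.getNumPartitions())  # 8
```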
fugue_spark/dataframe.py CHANGED
@@ -30,7 +30,9 @@ from fugue.plugins import (
     rename,
     select_columns,
 )
-from fugue_spark._utils.convert import to_cast_expression, to_schema, to_type_safe_input
+
+from ._utils.convert import to_cast_expression, to_pandas, to_schema, to_type_safe_input
+from ._utils.misc import is_spark_connect, is_spark_dataframe
 
 
 class SparkDataFrame(DataFrame):
@@ -51,12 +53,12 @@ class SparkDataFrame(DataFrame):
 
     def __init__(self, df: Any = None, schema: Any = None):  # noqa: C901
         self._lock = SerializableRLock()
-        if isinstance(df, ps.DataFrame):
+        if is_spark_dataframe(df):
             if schema is not None:
                 schema = to_schema(schema).assert_not_empty()
                 has_cast, expr = to_cast_expression(df, schema, True)
                 if has_cast:
-                    df = df.selectExpr(*expr)
+                    df = df.selectExpr(*expr)  # type: ignore
             else:
                 schema = to_schema(df).assert_not_empty()
         self._native = df
@@ -94,7 +96,7 @@ class SparkDataFrame(DataFrame):
             data = list(to_type_safe_input(self.native.collect(), self.schema))
             res: LocalBoundedDataFrame = ArrayDataFrame(data, self.schema)
         else:
-            res = PandasDataFrame(self.native.toPandas(), self.schema)
+            res = PandasDataFrame(self.as_pandas(), self.schema)
         if self.has_metadata:
             res.reset_metadata(self.metadata)
         return res
@@ -126,7 +128,7 @@ class SparkDataFrame(DataFrame):
         return SparkDataFrame(self.native[schema.names])
 
     def as_pandas(self) -> pd.DataFrame:
-        return self.native.toPandas()
+        return to_pandas(self.native)
 
     def rename(self, columns: Dict[str, str]) -> DataFrame:
         try:
@@ -150,6 +152,9 @@ class SparkDataFrame(DataFrame):
     def as_array_iterable(
         self, columns: Optional[List[str]] = None, type_safe: bool = False
     ) -> Iterable[Any]:
+        if is_spark_connect(self.native):  # pragma: no cover
+            yield from self.as_array(columns, type_safe=type_safe)
+            return
         sdf = self._select_columns(columns)
         if not type_safe:
             for row in to_type_safe_input(sdf.native.rdd.toLocalIterator(), sdf.schema):
@@ -182,47 +187,47 @@ class SparkDataFrame(DataFrame):
         return SparkDataFrame(self.native.select(*columns))
 
 
-@is_df.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_df.candidate(lambda df: is_spark_dataframe(df))
 def _spark_is_df(df: ps.DataFrame) -> bool:
     return True
 
 
-@get_num_partitions.candidate(lambda df: isinstance(df, ps.DataFrame))
+@get_num_partitions.candidate(lambda df: is_spark_dataframe(df))
 def _spark_num_partitions(df: ps.DataFrame) -> int:
     return df.rdd.getNumPartitions()
 
 
-@count.candidate(lambda df: isinstance(df, ps.DataFrame))
+@count.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_count(df: ps.DataFrame) -> int:
     return df.count()
 
 
-@is_bounded.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_bounded.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_bounded(df: ps.DataFrame) -> bool:
     return True
 
 
-@is_empty.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_empty.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_empty(df: ps.DataFrame) -> bool:
     return df.first() is None
 
 
-@is_local.candidate(lambda df: isinstance(df, ps.DataFrame))
+@is_local.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_is_local(df: ps.DataFrame) -> bool:
     return False
 
 
-@as_local_bounded.candidate(lambda df: isinstance(df, ps.DataFrame))
+@as_local_bounded.candidate(lambda df: is_spark_dataframe(df))
 def _spark_df_as_local(df: ps.DataFrame) -> pd.DataFrame:
-    return df.toPandas()
+    return to_pandas(df)
 
 
-@get_column_names.candidate(lambda df: isinstance(df, ps.DataFrame))
+@get_column_names.candidate(lambda df: is_spark_dataframe(df))
 def _get_spark_df_columns(df: ps.DataFrame) -> List[Any]:
     return df.columns
 
 
-@rename.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@rename.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _rename_spark_df(
     df: ps.DataFrame, columns: Dict[str, Any], as_fugue: bool = False
 ) -> ps.DataFrame:
@@ -232,7 +237,7 @@ def _rename_spark_df(
     return _adjust_df(_rename_spark_dataframe(df, columns), as_fugue=as_fugue)
 
 
-@drop_columns.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@drop_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _drop_spark_df_columns(
     df: ps.DataFrame, columns: List[str], as_fugue: bool = False
 ) -> Any:
@@ -244,7 +249,7 @@ def _drop_spark_df_columns(
     return _adjust_df(df[cols], as_fugue=as_fugue)
 
 
-@select_columns.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@select_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _select_spark_df_columns(
     df: ps.DataFrame, columns: List[Any], as_fugue: bool = False
 ) -> Any:
@@ -254,7 +259,7 @@ def _select_spark_df_columns(
     return _adjust_df(df[columns], as_fugue=as_fugue)
 
 
-@head.candidate(lambda df, *args, **kwargs: isinstance(df, ps.DataFrame))
+@head.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
 def _spark_df_head(
     df: ps.DataFrame,
     n: int,
@@ -264,7 +269,7 @@ def _spark_df_head(
     if columns is not None:
         df = df[columns]
     res = df.limit(n)
-    return SparkDataFrame(res).as_local() if as_fugue else res.toPandas()
+    return SparkDataFrame(res).as_local() if as_fugue else to_pandas(res)
 
 
 def _rename_spark_dataframe(df: ps.DataFrame, names: Dict[str, Any]) -> ps.DataFrame:
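Note: with the `.candidate` registrations above keyed on `is_spark_dataframe`, fugue's functional API dispatches to these implementations for both classic and Spark Connect DataFrames. A brief, hedged usage sketch (the session, data, and printed values are illustrative; `fugue.api` function names are assumed from the file list in this diff):

```python
# Illustrative: fugue's functional API dispatching on a Spark DataFrame.
import fugue.api as fa
from pyspark.sql import SparkSession

import fugue_spark  # ensures the Spark plugins are registered  # noqa: F401

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([(0, "a"), (1, "b")], "id long, v string")

print(fa.get_column_names(df))  # ['id', 'v']
print(fa.count(df))             # 2
print(fa.is_local(df))          # False
```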