streamlit-nightly 1.36.1.dev20240702__py2.py3-none-any.whl → 1.36.1.dev20240704__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. streamlit/commands/navigation.py +1 -1
  2. streamlit/components/v1/component_arrow.py +16 -11
  3. streamlit/components/v1/custom_component.py +2 -1
  4. streamlit/dataframe_util.py +884 -0
  5. streamlit/delta_generator.py +6 -4
  6. streamlit/elements/arrow.py +26 -45
  7. streamlit/elements/lib/built_in_chart_utils.py +78 -19
  8. streamlit/elements/lib/column_config_utils.py +1 -1
  9. streamlit/elements/lib/pandas_styler_utils.py +4 -2
  10. streamlit/elements/lib/policies.py +60 -8
  11. streamlit/elements/lib/utils.py +100 -10
  12. streamlit/elements/map.py +4 -15
  13. streamlit/elements/metric.py +5 -2
  14. streamlit/elements/plotly_chart.py +11 -12
  15. streamlit/elements/vega_charts.py +19 -31
  16. streamlit/elements/widgets/button.py +17 -15
  17. streamlit/elements/widgets/camera_input.py +15 -10
  18. streamlit/elements/widgets/chat.py +9 -11
  19. streamlit/elements/widgets/checkbox.py +13 -11
  20. streamlit/elements/widgets/color_picker.py +14 -10
  21. streamlit/elements/widgets/data_editor.py +18 -19
  22. streamlit/elements/widgets/file_uploader.py +15 -10
  23. streamlit/elements/widgets/multiselect.py +13 -15
  24. streamlit/elements/widgets/number_input.py +13 -11
  25. streamlit/elements/widgets/radio.py +13 -15
  26. streamlit/elements/widgets/select_slider.py +13 -13
  27. streamlit/elements/widgets/selectbox.py +13 -15
  28. streamlit/elements/widgets/slider.py +14 -10
  29. streamlit/elements/widgets/text_widgets.py +21 -17
  30. streamlit/elements/widgets/time_widgets.py +18 -16
  31. streamlit/elements/write.py +7 -15
  32. streamlit/runtime/caching/cache_utils.py +2 -5
  33. streamlit/runtime/state/common.py +51 -2
  34. streamlit/runtime/state/session_state.py +2 -1
  35. streamlit/runtime/state/session_state_proxy.py +1 -1
  36. streamlit/runtime/state/widgets.py +1 -1
  37. streamlit/static/asset-manifest.json +2 -2
  38. streamlit/static/index.html +1 -1
  39. streamlit/static/static/js/{main.e2ab315a.js → main.28e3c6e9.js} +2 -2
  40. streamlit/testing/v1/element_tree.py +3 -3
  41. streamlit/type_util.py +0 -1069
  42. {streamlit_nightly-1.36.1.dev20240702.dist-info → streamlit_nightly-1.36.1.dev20240704.dist-info}/METADATA +1 -1
  43. {streamlit_nightly-1.36.1.dev20240702.dist-info → streamlit_nightly-1.36.1.dev20240704.dist-info}/RECORD +48 -47
  44. /streamlit/static/static/js/{main.e2ab315a.js.LICENSE.txt → main.28e3c6e9.js.LICENSE.txt} +0 -0
  45. {streamlit_nightly-1.36.1.dev20240702.data → streamlit_nightly-1.36.1.dev20240704.data}/scripts/streamlit.cmd +0 -0
  46. {streamlit_nightly-1.36.1.dev20240702.dist-info → streamlit_nightly-1.36.1.dev20240704.dist-info}/WHEEL +0 -0
  47. {streamlit_nightly-1.36.1.dev20240702.dist-info → streamlit_nightly-1.36.1.dev20240704.dist-info}/entry_points.txt +0 -0
  48. {streamlit_nightly-1.36.1.dev20240702.dist-info → streamlit_nightly-1.36.1.dev20240704.dist-info}/top_level.txt +0 -0
streamlit/dataframe_util.py (new file)
@@ -0,0 +1,884 @@
+ # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """A bunch of useful utilities for dealing with dataframes."""
+
+ from __future__ import annotations
+
+ import contextlib
+ import math
+ from enum import Enum, EnumMeta, auto
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     Final,
+     Iterable,
+     Protocol,
+     Sequence,
+     TypeVar,
+     Union,
+     cast,
+ )
+
+ from typing_extensions import TypeAlias, TypeGuard
+
+ import streamlit as st
+ from streamlit import config, errors, logger, string_util
+ from streamlit.type_util import is_type
+
+ if TYPE_CHECKING:
+     import numpy as np
+     import pyarrow as pa
+     from pandas import DataFrame, Index, Series
+     from pandas.core.indexing import _iLocIndexer
+     from pandas.io.formats.style import Styler
+
+ _LOGGER: Final = logger.get_logger(__name__)
+
+
+ # Maximum number of rows to request from an unevaluated (out-of-core) dataframe
+ _MAX_UNEVALUATED_DF_ROWS = 10000
+
+ _PANDAS_STYLER_TYPE_STR: Final = "pandas.io.formats.style.Styler"
+ _SNOWPARK_DF_TYPE_STR: Final = "snowflake.snowpark.dataframe.DataFrame"
+ _SNOWPARK_DF_ROW_TYPE_STR: Final = "snowflake.snowpark.row.Row"
+ _SNOWPARK_TABLE_TYPE_STR: Final = "snowflake.snowpark.table.Table"
+ _PYSPARK_DF_TYPE_STR: Final = "pyspark.sql.dataframe.DataFrame"
+ _MODIN_DF_TYPE_STR: Final = "modin.pandas.dataframe.DataFrame"
+ _MODIN_SERIES_TYPE_STR: Final = "modin.pandas.series.Series"
+ _SNOWPANDAS_DF_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.dataframe.DataFrame"
+ _SNOWPANDAS_SERIES_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.series.Series"
+
+
+ V_co = TypeVar(
+     "V_co",
+     covariant=True,  # https://peps.python.org/pep-0484/#covariance-and-contravariance
+ )
+
+
+ class DataFrameGenericAlias(Protocol[V_co]):
+     """Technically not a GenericAlias, but serves the same purpose in
+     OptionSequence below, in that it is a type which admits DataFrame,
+     but is generic. This allows OptionSequence to be a fully generic type,
+     significantly increasing its usefulness.
+
+     We can't use types.GenericAlias, as it is only available from python>=3.9,
+     and isn't easily back-ported.
+     """
+
+     @property
+     def iloc(self) -> _iLocIndexer: ...
+
+
+ OptionSequence: TypeAlias = Union[
+     Iterable[V_co],
+     DataFrameGenericAlias[V_co],
+ ]
+
+ # Various data types supported by our dataframe processing,
+ # used for commands like `st.dataframe`, `st.table`, `st.map`,
+ # `st.line_chart`...
+ Data: TypeAlias = Union[
+     "DataFrame",
+     "Series",
+     "Styler",
+     "Index",
+     "pa.Table",
+     "np.ndarray",
+     Iterable[Any],
+     Dict[Any, Any],
+     None,
+ ]
+
+
+ class DataFormat(Enum):
+     """DataFormat is used to determine the format of the data."""
+
+     UNKNOWN = auto()
+     EMPTY = auto()  # None
+     PANDAS_DATAFRAME = auto()  # pd.DataFrame
+     PANDAS_SERIES = auto()  # pd.Series
+     PANDAS_INDEX = auto()  # pd.Index
+     NUMPY_LIST = auto()  # np.array[Scalar]
+     NUMPY_MATRIX = auto()  # np.array[List[Scalar]]
+     PYARROW_TABLE = auto()  # pyarrow.Table
+     SNOWPARK_OBJECT = auto()  # Snowpark DataFrame, Table, List[Row]
+     PYSPARK_OBJECT = auto()  # pyspark.DataFrame
+     MODIN_OBJECT = auto()  # Modin DataFrame, Series
+     SNOWPANDAS_OBJECT = auto()  # Snowpandas DataFrame, Series
+     PANDAS_STYLER = auto()  # pandas Styler
+     LIST_OF_RECORDS = auto()  # List[Dict[str, Scalar]]
+     LIST_OF_ROWS = auto()  # List[List[Scalar]]
+     LIST_OF_VALUES = auto()  # List[Scalar]
+     TUPLE_OF_VALUES = auto()  # Tuple[Scalar]
+     SET_OF_VALUES = auto()  # Set[Scalar]
+     COLUMN_INDEX_MAPPING = auto()  # {column: {index: value}}
+     COLUMN_VALUE_MAPPING = auto()  # {column: List[values]}
+     COLUMN_SERIES_MAPPING = auto()  # {column: Series(values)}
+     KEY_VALUE_DICT = auto()  # {index: value}
+
+
+ def is_dataframe_like(obj: object) -> bool:
+     """True if the object is a dataframe-like object.
+
+     This does not include basic collection types like list, dict, tuple, etc.
+     """
+
+     if obj is None or isinstance(
+         obj, (list, tuple, set, dict, str, bytes, int, float, bool)
+     ):
+         # Basic types are not considered dataframe-like, so we can
+         # return False early to avoid unnecessary checks.
+         return False
+
+     return determine_data_format(obj) in [
+         DataFormat.PANDAS_DATAFRAME,
+         DataFormat.PANDAS_SERIES,
+         DataFormat.PANDAS_INDEX,
+         DataFormat.PANDAS_STYLER,
+         DataFormat.NUMPY_LIST,
+         DataFormat.NUMPY_MATRIX,
+         DataFormat.PYARROW_TABLE,
+         DataFormat.SNOWPARK_OBJECT,
+         DataFormat.PYSPARK_OBJECT,
+         DataFormat.MODIN_OBJECT,
+         DataFormat.SNOWPANDAS_OBJECT,
+     ]
+
+
+ def is_unevaluated_data_object(obj: object) -> bool:
+     """True if the object is one of the supported unevaluated data objects:
+
+     Currently supported objects are:
+     - Snowpark DataFrame / Table
+     - PySpark DataFrame
+     - Modin DataFrame / Series
+     - Snowpandas DataFrame / Series
+
+     Unevaluated means that the data is not yet in the local memory.
+     Unevaluated data objects are treated differently from other data objects by only
+     requesting a subset of the data instead of loading all data into memory.
+     """
+     return (
+         is_snowpark_data_object(obj)
+         or is_pyspark_data_object(obj)
+         or is_snowpandas_data_object(obj)
+         or is_modin_data_object(obj)
+     )
+
+
+ def is_snowpark_data_object(obj: object) -> bool:
+     """True if obj is a Snowpark DataFrame or Table."""
+     return is_type(obj, _SNOWPARK_TABLE_TYPE_STR) or is_type(obj, _SNOWPARK_DF_TYPE_STR)
+
+
+ def is_snowpark_row_list(obj: object) -> bool:
+     """True if obj is a list of snowflake.snowpark.row.Row."""
+     if not isinstance(obj, list):
+         return False
+     if len(obj) < 1:
+         return False
+     if not hasattr(obj[0], "__class__"):
+         return False
+     return is_type(obj[0], _SNOWPARK_DF_ROW_TYPE_STR)
+
+
+ def is_pyspark_data_object(obj: object) -> bool:
+     """True if obj is of type pyspark.sql.dataframe.DataFrame"""
+     return (
+         is_type(obj, _PYSPARK_DF_TYPE_STR)
+         and hasattr(obj, "toPandas")
+         and callable(obj.toPandas)
+     )
+
+
+ def is_modin_data_object(obj: object) -> bool:
+     """True if obj is a Modin DataFrame or Series"""
+     return is_type(obj, _MODIN_DF_TYPE_STR) or is_type(obj, _MODIN_SERIES_TYPE_STR)
+
+
+ def is_snowpandas_data_object(obj: object) -> bool:
+     """True if obj is a Snowpark Pandas DataFrame or Series."""
+     return is_type(obj, _SNOWPANDAS_DF_TYPE_STR) or is_type(
+         obj, _SNOWPANDAS_SERIES_TYPE_STR
+     )
+
+
+ def is_pandas_styler(obj: object) -> TypeGuard[Styler]:
+     """True if obj is a pandas Styler."""
+     return is_type(obj, _PANDAS_STYLER_TYPE_STR)
+
+
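The type-check helpers above compare fully-qualified type names via streamlit.type_util.is_type, so optional dependencies such as pyspark or snowflake-snowpark-python never have to be imported just to detect them. A minimal usage sketch (illustrative only, not part of the diff; it assumes the module is importable as streamlit.dataframe_util, as the new file path suggests):

import pandas as pd

from streamlit import dataframe_util

df = pd.DataFrame({"a": [1, 2, 3]})
assert dataframe_util.is_dataframe_like(df)                # pandas DataFrame -> True
assert not dataframe_util.is_dataframe_like([1, 2, 3])     # plain list -> False early
assert not dataframe_util.is_unevaluated_data_object(df)   # already in local memory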
+ def convert_anything_to_pandas_df(
+     data: Any,
+     max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
+     ensure_copy: bool = False,
+ ) -> DataFrame:
+     """Try to convert different formats to a Pandas DataFrame.
+
+     Parameters
+     ----------
+     data : any
+         The data to convert to a Pandas DataFrame.
+
+     max_unevaluated_rows: int
+         If unevaluated data is detected, this function will evaluate it,
+         requesting at most max_unevaluated_rows rows. Defaults to 10k.
+
+     ensure_copy: bool
+         If True, make sure to always return a copy of the data. If False, it depends on
+         the type of the data. For example, a Pandas DataFrame will be returned as-is.
+
+     Returns
+     -------
+     pandas.DataFrame
+
+     """
+     import numpy as np
+     import pandas as pd
+
+     if isinstance(data, pd.DataFrame):
+         return data.copy() if ensure_copy else cast(pd.DataFrame, data)
+
+     if isinstance(data, (pd.Series, pd.Index)):
+         return pd.DataFrame(data)
+
+     if is_pandas_styler(data):
+         return cast(pd.DataFrame, data.data.copy() if ensure_copy else data.data)
+
+     if isinstance(data, np.ndarray):
+         return pd.DataFrame([]) if len(data.shape) == 0 else pd.DataFrame(data)
+
+     if is_modin_data_object(data):
+         data = data.head(max_unevaluated_rows)._to_pandas()
+
+         if isinstance(data, pd.Series):
+             data = data.to_frame()
+
+         if data.shape[0] == max_unevaluated_rows:
+             st.caption(
+                 f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
+                 "rows. Call `_to_pandas()` on the dataframe to show more."
+             )
+         return cast(pd.DataFrame, data)
+
+     if is_pyspark_data_object(data):
+         data = data.limit(max_unevaluated_rows).toPandas()
+         if data.shape[0] == max_unevaluated_rows:
+             st.caption(
+                 f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
+                 "rows. Call `toPandas()` on the dataframe to show more."
+             )
+         return cast(pd.DataFrame, data)
+
+     if is_snowpark_data_object(data):
+         data = data.limit(max_unevaluated_rows).to_pandas()
+         if data.shape[0] == max_unevaluated_rows:
+             st.caption(
+                 f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
+                 "rows. Call `to_pandas()` on the dataframe to show more."
+             )
+         return cast(pd.DataFrame, data)
+
+     if is_snowpandas_data_object(data):
+         data = data.head(max_unevaluated_rows).to_pandas()
+
+         if isinstance(data, pd.Series):
+             data = data.to_frame()
+
+         if data.shape[0] == max_unevaluated_rows:
+             st.caption(
+                 f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} "
+                 "rows. Call `to_pandas()` on the dataframe to show more."
+             )
+         return cast(pd.DataFrame, data)
+
+     # This is inefficient when data is a pyarrow.Table as it will be converted
+     # back to Arrow when marshalled to protobuf, but area/bar/line charts need
+     # DataFrame magic to generate the correct output.
+     if hasattr(data, "to_pandas"):
+         return cast(pd.DataFrame, data.to_pandas())
+
+     # Try to convert to pandas.DataFrame. This will raise an error if the data is not
+     # compatible with the pandas.DataFrame constructor.
+     try:
+         return pd.DataFrame(data)
+
+     except ValueError as ex:
+         if isinstance(data, dict):
+             with contextlib.suppress(ValueError):
+                 # Try to use index orient as back-up to support key-value dicts
+                 return pd.DataFrame.from_dict(data, orient="index")
+         raise errors.StreamlitAPIException(
+             f"""
+ Unable to convert object of type `{type(data)}` to `pandas.DataFrame`.
+ Offending object:
+ ```py
+ {data}
+ ```"""
+         ) from ex
+
+
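Illustrative usage sketch (not part of the diff); the function name comes from the file above, while the sample data is made up:

import pandas as pd

from streamlit.dataframe_util import convert_anything_to_pandas_df

records = [{"name": "Ada", "age": 36}, {"name": "Grace", "age": 45}]
df = convert_anything_to_pandas_df(records)            # list of records -> DataFrame

kv = convert_anything_to_pandas_df({"a": 1, "b": 2})   # scalar dict: falls back to orient="index"

copy = convert_anything_to_pandas_df(df, ensure_copy=True)
assert copy is not df                                   # mutations to `copy` won't affect `df`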
+ def convert_arrow_table_to_arrow_bytes(table: pa.Table) -> bytes:
+     """Serialize pyarrow.Table to Arrow IPC bytes.
+
+     Parameters
+     ----------
+     table : pyarrow.Table
+         A table to convert.
+
+     Returns
+     -------
+     bytes
+         The serialized Arrow IPC bytes.
+     """
+     try:
+         table = _maybe_truncate_table(table)
+     except RecursionError as err:
+         # This is a very unlikely edge case, but we want to make sure that
+         # it doesn't lead to unexpected behavior.
+         # If there is a recursion error, we just return the table as-is
+         # which will lead to the normal message-limit-exceeded error.
+         _LOGGER.warning(
+             "Recursion error while truncating Arrow table. This is not "
+             "supposed to happen.",
+             exc_info=err,
+         )
+
+     import pyarrow as pa
+
+     # Convert table to bytes
+     sink = pa.BufferOutputStream()
+     writer = pa.RecordBatchStreamWriter(sink, table.schema)
+     writer.write_table(table)
+     writer.close()
+     return cast(bytes, sink.getvalue().to_pybytes())
+
+
+ def convert_pandas_df_to_arrow_bytes(df: DataFrame) -> bytes:
+     """Serialize pandas.DataFrame to Arrow IPC bytes.
+
+     Parameters
+     ----------
+     df : pandas.DataFrame
+         A dataframe to convert.
+
+     Returns
+     -------
+     bytes
+         The serialized Arrow IPC bytes.
+     """
+     import pyarrow as pa
+
+     try:
+         table = pa.Table.from_pandas(df)
+     except (pa.ArrowTypeError, pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
+         _LOGGER.info(
+             "Serialization of dataframe to Arrow table was unsuccessful due to: %s. "
+             "Applying automatic fixes for column types to make the dataframe "
+             "Arrow-compatible.",
+             ex,
+         )
+         df = fix_arrow_incompatible_column_types(df)
+         table = pa.Table.from_pandas(df)
+     return convert_arrow_table_to_arrow_bytes(table)
+
+
+ def convert_arrow_bytes_to_pandas_df(source: bytes) -> DataFrame:
+     """Convert Arrow bytes (IPC format) to pandas.DataFrame.
+
+     When using this function in production, make sure that pyarrow >= 14.0.1
+     is installed, because pyarrow < 14.0.1 has a critical security
+     vulnerability.
+
+     Parameters
+     ----------
+     source : bytes
+         A bytes object to convert.
+
+     Returns
+     -------
+     pandas.DataFrame
+         The converted dataframe.
+     """
+     import pyarrow as pa
+
+     reader = pa.RecordBatchStreamReader(source)
+     return reader.read_pandas()
+
+
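A round-trip sketch through the two serializers above (illustrative only, not part of the diff; for simple column types like the ones below the round trip should preserve the frame):

import pandas as pd

from streamlit.dataframe_util import (
    convert_arrow_bytes_to_pandas_df,
    convert_pandas_df_to_arrow_bytes,
)

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
ipc_bytes = convert_pandas_df_to_arrow_bytes(df)       # Arrow IPC stream bytes
restored = convert_arrow_bytes_to_pandas_df(ipc_bytes)
assert restored.equals(df)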
+ def convert_anything_to_arrow_bytes(
+     data: Any,
+     max_unevaluated_rows: int = _MAX_UNEVALUATED_DF_ROWS,
+ ) -> bytes:
+     """Try to convert different formats to Arrow IPC format (bytes).
+
+     This method tries to directly convert the input data to Arrow bytes
+     for some supported formats, but falls back to conversion to a Pandas
+     DataFrame and then to Arrow bytes.
+
+     Parameters
+     ----------
+     data : any
+         The data to convert to Arrow bytes.
+
+     max_unevaluated_rows: int
+         If unevaluated data is detected, this function will evaluate it,
+         requesting at most max_unevaluated_rows rows. Defaults to 10k.
+
+     Returns
+     -------
+     bytes
+         The serialized Arrow IPC bytes.
+     """
+
+     import pyarrow as pa
+
+     if isinstance(data, pa.Table):
+         return convert_arrow_table_to_arrow_bytes(data)
+
+     # Fallback: try to convert to pandas DataFrame
+     # and then to Arrow bytes
+     df = convert_anything_to_pandas_df(data, max_unevaluated_rows)
+     return convert_pandas_df_to_arrow_bytes(df)
+
+
+ def convert_anything_to_sequence(obj: OptionSequence[V_co]) -> Sequence[V_co]:
+     """Try to convert different formats to an indexable Sequence.
+
+     If the input is a dataframe-like object, we just select the first
+     column to iterate over. If the input cannot be converted to a sequence,
+     a TypeError is raised.
+
+     Parameters
+     ----------
+     obj : OptionSequence
+         The object to convert to a sequence.
+
+     Returns
+     -------
+     Sequence
+         The converted sequence.
+     """
+     if obj is None:
+         return []  # type: ignore
+
+     if isinstance(obj, (str, list, tuple, set, range, EnumMeta)):
+         # This also ensures that the sequence is copied to prevent
+         # potential mutations to the original object.
+         return list(obj)
+
+     if isinstance(obj, dict):
+         return list(obj.keys())
+
+     # Fallback to our DataFrame conversion logic:
+     try:
+         # We use ensure_copy here because the return value of this function is
+         # saved in a widget serde class instance to be used in later script runs,
+         # and we don't want mutations to the options object passed to a
+         # widget to affect the widget.
+         # (See https://github.com/streamlit/streamlit/issues/7534)
+         data_df = convert_anything_to_pandas_df(obj, ensure_copy=True)
+         # Return first column as a list:
+         return (
+             [] if data_df.empty else cast(Sequence[V_co], data_df.iloc[:, 0].to_list())
+         )
+     except errors.StreamlitAPIException as e:
+         raise TypeError(
+             "Object is not an iterable and could not be converted to one. "
+             f"Object type: {type(obj)}"
+         ) from e
+
+
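Illustrative sketch of the sequence conversion above (not part of the diff; sample data is made up):

import pandas as pd

from streamlit.dataframe_util import convert_anything_to_sequence

assert convert_anything_to_sequence(["a", "b"]) == ["a", "b"]        # copied list
assert convert_anything_to_sequence({"a": 1, "b": 2}) == ["a", "b"]  # dict keys only
df = pd.DataFrame({"col": [1, 2, 3], "other": [4, 5, 6]})
assert convert_anything_to_sequence(df) == [1, 2, 3]                 # first column only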
+ def _maybe_truncate_table(
+     table: pa.Table, truncated_rows: int | None = None
+ ) -> pa.Table:
+     """Experimental feature to automatically truncate tables that
+     are larger than the maximum allowed message size. It needs to be enabled
+     via the server.enableArrowTruncation config option.
+
+     Parameters
+     ----------
+     table : pyarrow.Table
+         A table to truncate.
+
+     truncated_rows : int or None
+         The number of rows that have been truncated so far. This is used by
+         the recursion logic to keep track of the total number of truncated
+         rows.
+
+     """
+
+     if config.get_option("server.enableArrowTruncation"):
+         # This is an optimization problem: We don't know at what row
+         # the perfect cut-off is to comply with the max size. But we want to figure
+         # it out in as few iterations as possible. We almost always will cut out
+         # more than required to keep the iterations low.
+
+         # The maximum size allowed for protobuf messages in bytes:
+         max_message_size = int(config.get_option("server.maxMessageSize") * 1e6)
+         # We add 1 MB for other overhead related to the protobuf message.
+         # This is a very conservative estimate, but it should be good enough.
+         table_size = int(table.nbytes + 1 * 1e6)
+         table_rows = table.num_rows
+
+         if table_rows > 1 and table_size > max_message_size:
+             # targeted rows == the number of rows the table should be truncated to.
+             # Calculate an approximation of how many rows we need to truncate to.
+             targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
+             # Make sure to cut out at least a couple of rows to avoid running
+             # this logic too often since it is quite inefficient and could lead
+             # to infinite recursion without these precautions.
+             targeted_rows = math.floor(
+                 max(
+                     min(
+                         # Cut out:
+                         # an additional 5% of the estimated num rows to cut out:
+                         targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
+                         # at least 1% of table size:
+                         table_rows - (table_rows * 0.01),
+                         # at least 5 rows:
+                         table_rows - 5,
+                     ),
+                     1,  # but it should always have at least 1 row
+                 )
+             )
+             sliced_table = table.slice(0, targeted_rows)
+             return _maybe_truncate_table(
+                 sliced_table, (truncated_rows or 0) + (table_rows - targeted_rows)
+             )
+
+     if truncated_rows:
+         displayed_rows = string_util.simplify_number(table.num_rows)
+         total_rows = string_util.simplify_number(table.num_rows + truncated_rows)
+
+         if displayed_rows == total_rows:
+             # If the simplified numbers are the same,
+             # we just display the exact numbers.
+             displayed_rows = str(table.num_rows)
+             total_rows = str(table.num_rows + truncated_rows)
+
+         st.caption(
+             f"⚠️ Showing {displayed_rows} out of {total_rows} "
+             "rows due to data size limitations."
+         )
+
+     return table
+
+
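A worked example of the truncation arithmetic above, with made-up numbers (illustrative only, not part of the diff):

import math

# Assume a 200 MB message limit and a ~400 MB table with 1,000,000 rows.
max_message_size = int(200 * 1e6)
table_size = int(400 * 1e6 + 1 * 1e6)   # table.nbytes plus the 1 MB protobuf overhead
table_rows = 1_000_000

targeted_rows = math.ceil(table_rows * (max_message_size / table_size))  # 498_754
targeted_rows = math.floor(
    max(
        min(
            targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),  # 473_692
            table_rows - (table_rows * 0.01),                                 # 990_000
            table_rows - 5,                                                   # 999_995
        ),
        1,
    )
)
# targeted_rows == 473_692: the first pass slices to roughly 474k rows, and the
# function then recurses on the sliced table until it fits under the limit.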
+ def is_colum_type_arrow_incompatible(column: Series[Any] | Index) -> bool:
+     """Return True if the column type is known to cause issues during Arrow conversion."""
+     from pandas.api.types import infer_dtype, is_dict_like, is_list_like
+
+     if column.dtype.kind in [
+         "c",  # complex64, complex128, complex256
+     ]:
+         return True
+
+     if str(column.dtype) in {
+         # These period types are not yet supported by our frontend impl.
+         # See comments in Quiver.ts for more details.
+         "period[B]",
+         "period[N]",
+         "period[ns]",
+         "period[U]",
+         "period[us]",
+     }:
+         return True
+
+     if column.dtype == "object":
+         # The dtype of mixed type columns is always object. The actual type of the
+         # column values can be determined via the infer_dtype function:
+         # https://pandas.pydata.org/docs/reference/api/pandas.api.types.infer_dtype.html
+         inferred_type = infer_dtype(column, skipna=True)
+
+         if inferred_type in [
+             "mixed-integer",
+             "complex",
+         ]:
+             return True
+         elif inferred_type == "mixed":
+             # This includes most of the more complex/custom types (objects, dicts, lists, ...)
+             if len(column) == 0 or not hasattr(column, "iloc"):
+                 # The column seems to be invalid, so we assume it is incompatible.
+                 # But this would most likely never happen since empty columns
+                 # cannot be mixed.
+                 return True
+
+             # Get the first value to check if it is a supported list-like type.
+             first_value = column.iloc[0]
+
+             if (
+                 not is_list_like(first_value)
+                 # dicts are list-like, but have issues in Arrow JS (see comments in Quiver.ts)
+                 or is_dict_like(first_value)
+                 # Frozensets are list-like, but are not compatible with pyarrow.
+                 or isinstance(first_value, frozenset)
+             ):
+                 # This seems to be an incompatible list-like type
+                 return True
+             return False
+     # We did not detect an incompatible type, so we assume it is compatible:
+     return False
+
+
+ def fix_arrow_incompatible_column_types(
+     df: DataFrame, selected_columns: list[str] | None = None
+ ) -> DataFrame:
+     """Fix column types that are not supported by Arrow table.
+
+     This includes mixed types (e.g. mix of integers and strings)
+     as well as complex numbers (complex128 type). These types will cause
+     errors during conversion of the dataframe to an Arrow table.
+     It is fixed by converting all values of the column to strings.
+     This is sufficient for displaying the data on the frontend.
+
+     Parameters
+     ----------
+     df : pandas.DataFrame
+         A dataframe to fix.
+
+     selected_columns: List[str] or None
+         A list of columns to fix. If None, all columns are evaluated.
+
+     Returns
+     -------
+     The fixed dataframe.
+     """
+     import pandas as pd
+
+     # Make a copy, but only initialize if necessary to preserve memory.
+     df_copy: DataFrame | None = None
+     for col in selected_columns or df.columns:
+         if is_colum_type_arrow_incompatible(df[col]):
+             if df_copy is None:
+                 df_copy = df.copy()
+             df_copy[col] = df[col].astype("string")
+
+     # The index can also contain mixed types
+     # causing Arrow issues during conversion.
+     # Skipping multi-indices since they won't return
+     # the correct value from infer_dtype
+     if not selected_columns and (
+         not isinstance(
+             df.index,
+             pd.MultiIndex,
+         )
+         and is_colum_type_arrow_incompatible(df.index)
+     ):
+         if df_copy is None:
+             df_copy = df.copy()
+         df_copy.index = df.index.astype("string")
+     return df_copy if df_copy is not None else df
+
+
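Illustrative sketch of the column fix above (not part of the diff; the sample frame is made up):

import pandas as pd
import pyarrow as pa

from streamlit.dataframe_util import fix_arrow_incompatible_column_types

df = pd.DataFrame({"mixed": [1, "two", 3.0], "ok": [1, 2, 3]})
# pa.Table.from_pandas(df) would fail on the "mixed" column (ArrowInvalid/ArrowTypeError).
fixed = fix_arrow_incompatible_column_types(df)
assert str(fixed["mixed"].dtype) == "string"   # incompatible column cast to pandas string dtype
table = pa.Table.from_pandas(fixed)            # now serializes cleanly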
+ def _is_list_of_scalars(data: Iterable[Any]) -> bool:
+     """Check if the list only contains scalar values."""
+     from pandas.api.types import infer_dtype
+
+     # Overview of all values that are interpreted as scalar:
+     # https://pandas.pydata.org/docs/reference/api/pandas.api.types.is_scalar.html
+     return infer_dtype(data, skipna=True) not in ["mixed", "unknown-array"]
+
+
+ def determine_data_format(input_data: Any) -> DataFormat:
+     """Determine the data format of the input data.
+
+     Parameters
+     ----------
+     input_data : Any
+         The input data to determine the data format of.
+
+     Returns
+     -------
+     DataFormat
+         The data format of the input data.
+     """
+     import numpy as np
+     import pandas as pd
+     import pyarrow as pa
+
+     if input_data is None:
+         return DataFormat.EMPTY
+     elif isinstance(input_data, pd.DataFrame):
+         return DataFormat.PANDAS_DATAFRAME
+     elif isinstance(input_data, np.ndarray):
+         if len(input_data.shape) == 1:
+             # For technical reasons, we need to distinguish
+             # one-dimensional numpy arrays from multidimensional ones.
+             return DataFormat.NUMPY_LIST
+         return DataFormat.NUMPY_MATRIX
+     elif isinstance(input_data, pa.Table):
+         return DataFormat.PYARROW_TABLE
+     elif isinstance(input_data, pd.Series):
+         return DataFormat.PANDAS_SERIES
+     elif isinstance(input_data, pd.Index):
+         return DataFormat.PANDAS_INDEX
+     elif is_pandas_styler(input_data):
+         return DataFormat.PANDAS_STYLER
+     elif is_snowpark_data_object(input_data):
+         return DataFormat.SNOWPARK_OBJECT
+     elif is_modin_data_object(input_data):
+         return DataFormat.MODIN_OBJECT
+     elif is_snowpandas_data_object(input_data):
+         return DataFormat.SNOWPANDAS_OBJECT
+     elif is_pyspark_data_object(input_data):
+         return DataFormat.PYSPARK_OBJECT
+     elif isinstance(input_data, (list, tuple, set)):
+         if _is_list_of_scalars(input_data):
+             # -> one-dimensional data structure
+             if isinstance(input_data, tuple):
+                 return DataFormat.TUPLE_OF_VALUES
+             if isinstance(input_data, set):
+                 return DataFormat.SET_OF_VALUES
+             return DataFormat.LIST_OF_VALUES
+         else:
+             # -> Multi-dimensional data structure
+             # This should always contain at least one element,
+             # otherwise the values type from infer_dtype would have been empty
+             first_element = next(iter(input_data))
+             if isinstance(first_element, dict):
+                 return DataFormat.LIST_OF_RECORDS
+             if isinstance(first_element, (list, tuple, set)):
+                 return DataFormat.LIST_OF_ROWS
+     elif isinstance(input_data, dict):
+         if not input_data:
+             return DataFormat.KEY_VALUE_DICT
+         if len(input_data) > 0:
+             first_value = next(iter(input_data.values()))
+             if isinstance(first_value, dict):
+                 return DataFormat.COLUMN_INDEX_MAPPING
+             if isinstance(first_value, (list, tuple)):
+                 return DataFormat.COLUMN_VALUE_MAPPING
+             if isinstance(first_value, pd.Series):
+                 return DataFormat.COLUMN_SERIES_MAPPING
+             # In the future, we could potentially also support the tight & split formats here
+             if _is_list_of_scalars(input_data.values()):
+                 # Only use the key-value dict format if the values are only scalar values
+                 return DataFormat.KEY_VALUE_DICT
+     return DataFormat.UNKNOWN
+
+
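A few example classifications for the detection logic above (illustrative only, not part of the diff):

import numpy as np
import pandas as pd

from streamlit.dataframe_util import DataFormat, determine_data_format

assert determine_data_format(None) == DataFormat.EMPTY
assert determine_data_format(pd.DataFrame({"a": [1]})) == DataFormat.PANDAS_DATAFRAME
assert determine_data_format(np.array([1, 2, 3])) == DataFormat.NUMPY_LIST
assert determine_data_format([{"a": 1}, {"a": 2}]) == DataFormat.LIST_OF_RECORDS
assert determine_data_format({"a": [1, 2], "b": [3, 4]}) == DataFormat.COLUMN_VALUE_MAPPING
assert determine_data_format({"a": 1, "b": 2}) == DataFormat.KEY_VALUE_DICT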
+ def _unify_missing_values(df: DataFrame) -> DataFrame:
+     """Unify all missing values in a DataFrame to None.
+
+     Pandas uses a variety of values to represent missing values, including np.nan,
+     NaT, None, and pd.NA. This function replaces all of these values with None,
+     which is the only missing value type that is supported by all data formats.
+     """
+     import numpy as np
+
+     return df.fillna(np.nan).replace([np.nan], [None])
+
+
+ def convert_pandas_df_to_data_format(
+     df: DataFrame, data_format: DataFormat
+ ) -> (
+     DataFrame
+     | Series[Any]
+     | pa.Table
+     | np.ndarray[Any, np.dtype[Any]]
+     | tuple[Any]
+     | list[Any]
+     | set[Any]
+     | dict[str, Any]
+ ):
+     """Convert a Pandas DataFrame to the specified data format.
+
+     Parameters
+     ----------
+     df : pd.DataFrame
+         The dataframe to convert.
+
+     data_format : DataFormat
+         The data format to convert to.
+
+     Returns
+     -------
+     pd.DataFrame, pd.Series, pyarrow.Table, np.ndarray, list, set, tuple, or dict.
+         The converted dataframe.
+     """
+
+     if data_format in [
+         DataFormat.EMPTY,
+         DataFormat.PANDAS_DATAFRAME,
+         DataFormat.SNOWPARK_OBJECT,
+         DataFormat.PYSPARK_OBJECT,
+         DataFormat.PANDAS_INDEX,
+         DataFormat.PANDAS_STYLER,
+         DataFormat.MODIN_OBJECT,
+         DataFormat.SNOWPANDAS_OBJECT,
+     ]:
+         return df
+     elif data_format == DataFormat.NUMPY_LIST:
+         import numpy as np
+
+         # It's a 1-dimensional array, so we only return
+         # the first column as numpy array
+         # Calling to_numpy() on the full DataFrame would result in:
+         # [[1], [2]] instead of [1, 2]
+         return np.ndarray(0) if df.empty else df.iloc[:, 0].to_numpy()
+     elif data_format == DataFormat.NUMPY_MATRIX:
+         import numpy as np
+
+         return np.ndarray(0) if df.empty else df.to_numpy()
+     elif data_format == DataFormat.PYARROW_TABLE:
+         import pyarrow as pa
+
+         return pa.Table.from_pandas(df)
+     elif data_format == DataFormat.PANDAS_SERIES:
+         # Select first column in dataframe and create a new series based on the values
+         if len(df.columns) != 1:
+             raise ValueError(
+                 f"DataFrame is expected to have a single column but has {len(df.columns)}."
+             )
+         return df[df.columns[0]]
+     elif data_format == DataFormat.LIST_OF_RECORDS:
+         return _unify_missing_values(df).to_dict(orient="records")
+     elif data_format == DataFormat.LIST_OF_ROWS:
+         # to_numpy converts the dataframe to a list of rows
+         return _unify_missing_values(df).to_numpy().tolist()
+     elif data_format == DataFormat.COLUMN_INDEX_MAPPING:
+         return _unify_missing_values(df).to_dict(orient="dict")
+     elif data_format == DataFormat.COLUMN_VALUE_MAPPING:
+         return _unify_missing_values(df).to_dict(orient="list")
+     elif data_format == DataFormat.COLUMN_SERIES_MAPPING:
+         return df.to_dict(orient="series")
+     elif data_format in [
+         DataFormat.LIST_OF_VALUES,
+         DataFormat.TUPLE_OF_VALUES,
+         DataFormat.SET_OF_VALUES,
+     ]:
+         df = _unify_missing_values(df)
+         return_list = []
+         if len(df.columns) == 1:
+             # Get the first column and convert to list
+             return_list = df[df.columns[0]].tolist()
+         elif len(df.columns) >= 1:
+             raise ValueError(
+                 f"DataFrame is expected to have a single column but has {len(df.columns)}."
+             )
+         if data_format == DataFormat.TUPLE_OF_VALUES:
+             return tuple(return_list)
+         if data_format == DataFormat.SET_OF_VALUES:
+             return set(return_list)
+         return return_list
+     elif data_format == DataFormat.KEY_VALUE_DICT:
+         df = _unify_missing_values(df)
+         # The key is expected to be the index -> this will return the first column
+         # as a dict with index as key.
+         return {} if df.empty else df.iloc[:, 0].to_dict()
+
+     raise ValueError(f"Unsupported input data format: {data_format}")
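
Taken together, determine_data_format, convert_anything_to_pandas_df, and convert_pandas_df_to_data_format form a round trip that normalizes arbitrary input to pandas and then hands it back in the caller's original shape, presumably what commands like st.data_editor rely on. A hedged sketch (not part of the diff; sample data is made up):

import pandas as pd

from streamlit.dataframe_util import (
    convert_anything_to_pandas_df,
    convert_pandas_df_to_data_format,
    determine_data_format,
)

original = [{"name": "Ada", "age": 36}, {"name": "Grace", "age": 45}]
fmt = determine_data_format(original)          # DataFormat.LIST_OF_RECORDS
df = convert_anything_to_pandas_df(original)   # normalize to pandas
assert convert_pandas_df_to_data_format(df, fmt) == original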