streamlit-nightly 1.36.1.dev20240630__py2.py3-none-any.whl → 1.36.1.dev20240703__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. streamlit/commands/navigation.py +2 -2
  2. streamlit/components/v1/component_arrow.py +16 -11
  3. streamlit/components/v1/custom_component.py +2 -1
  4. streamlit/config.py +1 -136
  5. streamlit/dataframe_util.py +835 -0
  6. streamlit/delta_generator.py +5 -3
  7. streamlit/elements/arrow.py +17 -13
  8. streamlit/elements/dialog_decorator.py +1 -1
  9. streamlit/elements/exception.py +2 -8
  10. streamlit/elements/image.py +2 -1
  11. streamlit/elements/lib/built_in_chart_utils.py +78 -12
  12. streamlit/elements/lib/column_config_utils.py +1 -1
  13. streamlit/elements/lib/pandas_styler_utils.py +2 -2
  14. streamlit/elements/lib/policies.py +20 -2
  15. streamlit/elements/lib/utils.py +100 -10
  16. streamlit/elements/map.py +2 -2
  17. streamlit/elements/media.py +1 -1
  18. streamlit/elements/metric.py +5 -2
  19. streamlit/elements/plotly_chart.py +1 -1
  20. streamlit/elements/pyplot.py +26 -39
  21. streamlit/elements/vega_charts.py +6 -5
  22. streamlit/elements/widgets/button.py +1 -1
  23. streamlit/elements/widgets/camera_input.py +7 -2
  24. streamlit/elements/widgets/chat.py +1 -1
  25. streamlit/elements/widgets/checkbox.py +7 -2
  26. streamlit/elements/widgets/color_picker.py +7 -2
  27. streamlit/elements/widgets/data_editor.py +10 -9
  28. streamlit/elements/widgets/file_uploader.py +7 -2
  29. streamlit/elements/widgets/multiselect.py +6 -7
  30. streamlit/elements/widgets/number_input.py +7 -2
  31. streamlit/elements/widgets/radio.py +6 -7
  32. streamlit/elements/widgets/select_slider.py +6 -7
  33. streamlit/elements/widgets/selectbox.py +6 -7
  34. streamlit/elements/widgets/slider.py +7 -2
  35. streamlit/elements/widgets/text_widgets.py +8 -5
  36. streamlit/elements/widgets/time_widgets.py +7 -2
  37. streamlit/elements/write.py +5 -5
  38. streamlit/errors.py +0 -29
  39. streamlit/navigation/page.py +8 -3
  40. streamlit/proto/NewSession_pb2.pyi +1 -1
  41. streamlit/runtime/app_session.py +0 -4
  42. streamlit/runtime/caching/cache_utils.py +1 -1
  43. streamlit/runtime/scriptrunner/script_runner.py +7 -22
  44. streamlit/runtime/state/common.py +51 -2
  45. streamlit/runtime/state/session_state.py +2 -1
  46. streamlit/runtime/state/session_state_proxy.py +1 -1
  47. streamlit/runtime/state/widgets.py +1 -1
  48. streamlit/static/asset-manifest.json +2 -2
  49. streamlit/static/index.html +1 -1
  50. streamlit/static/static/js/main.28e3c6e9.js +2 -0
  51. streamlit/testing/v1/element_tree.py +3 -3
  52. streamlit/type_util.py +0 -1069
  53. streamlit/watcher/path_watcher.py +1 -2
  54. {streamlit_nightly-1.36.1.dev20240630.dist-info → streamlit_nightly-1.36.1.dev20240703.dist-info}/METADATA +1 -1
  55. {streamlit_nightly-1.36.1.dev20240630.dist-info → streamlit_nightly-1.36.1.dev20240703.dist-info}/RECORD +60 -59
  56. {streamlit_nightly-1.36.1.dev20240630.dist-info → streamlit_nightly-1.36.1.dev20240703.dist-info}/WHEEL +1 -1
  57. streamlit/static/static/js/main.0326e951.js +0 -2
  58. /streamlit/static/static/js/{main.0326e951.js.LICENSE.txt → main.28e3c6e9.js.LICENSE.txt} +0 -0
  59. {streamlit_nightly-1.36.1.dev20240630.data → streamlit_nightly-1.36.1.dev20240703.data}/scripts/streamlit.cmd +0 -0
  60. {streamlit_nightly-1.36.1.dev20240630.dist-info → streamlit_nightly-1.36.1.dev20240703.dist-info}/entry_points.txt +0 -0
  61. {streamlit_nightly-1.36.1.dev20240630.dist-info → streamlit_nightly-1.36.1.dev20240703.dist-info}/top_level.txt +0 -0
streamlit/type_util.py CHANGED
@@ -16,12 +16,8 @@
 
 from __future__ import annotations
 
-import contextlib
-import copy
-import math
 import re
 import types
-from enum import Enum, EnumMeta, auto
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -31,128 +27,30 @@ from typing import (
     NamedTuple,
     Protocol,
     Sequence,
-    Tuple,
     TypeVar,
     Union,
-    cast,
-    get_args,
     overload,
 )
 
 from typing_extensions import TypeAlias, TypeGuard
 
-import streamlit as st
-from streamlit import config, errors, logger, string_util
 from streamlit.errors import StreamlitAPIException
 
 if TYPE_CHECKING:
     import graphviz
-    import numpy as np
-    import pyarrow as pa
     import sympy
-    from pandas import DataFrame, Index, Series
-    from pandas.core.indexing import _iLocIndexer
-    from pandas.io.formats.style import Styler
-    from pandas.io.formats.style_renderer import StyleRenderer
     from plotly.graph_objs import Figure
     from pydeck import Deck
 
     from streamlit.runtime.secrets import Secrets
 
-
-# Maximum number of rows to request from an unevaluated (out-of-core) dataframe
-MAX_UNEVALUATED_DF_ROWS = 10000
-
-_LOGGER = logger.get_logger(__name__)
-
-# The array value field names are part of the larger set of possible value
-# field names. See the explanation for said set below. The message types
-# associated with these fields are distinguished by storing data in a `data`
-# field in their messages, meaning they need special treatment in certain
-# circumstances. Hence, they need their own, dedicated, sub-type.
-ArrayValueFieldName: TypeAlias = Literal[
-    "double_array_value",
-    "int_array_value",
-    "string_array_value",
-]
-
-# A frozenset containing the allowed values of the ArrayValueFieldName type.
-# Useful for membership checking.
-ARRAY_VALUE_FIELD_NAMES: Final = frozenset(
-    cast(
-        "tuple[ArrayValueFieldName, ...]",
-        # NOTE: get_args is not recursive, so this only works as long as
-        # ArrayValueFieldName remains flat.
-        get_args(ArrayValueFieldName),
-    )
-)
-
-# These are the possible field names that can be set in the `value` oneof-field
-# of the WidgetState message (schema found in .proto/WidgetStates.proto).
-# We need these as a literal type to ensure correspondence with the protobuf
-# schema in certain parts of the python code.
-# TODO(harahu): It would be preferable if this type was automatically derived
-# from the protobuf schema, rather than manually maintained. Not sure how to
-# achieve that, though.
-ValueFieldName: TypeAlias = Literal[
-    ArrayValueFieldName,
-    "arrow_value",
-    "bool_value",
-    "bytes_value",
-    "double_value",
-    "file_uploader_state_value",
-    "int_value",
-    "json_value",
-    "string_value",
-    "trigger_value",
-    "string_trigger_value",
-]
-
-V_co = TypeVar(
-    "V_co",
-    covariant=True,  # https://peps.python.org/pep-0484/#covariance-and-contravariance
-)
-
 T = TypeVar("T")
 
 
-class DataFrameGenericAlias(Protocol[V_co]):
-    """Technically not a GenericAlias, but serves the same purpose in
-    OptionSequence below, in that it is a type which admits DataFrame,
-    but is generic. This allows OptionSequence to be a fully generic type,
-    significantly increasing its usefulness.
-
-    We can't use types.GenericAlias, as it is only available from python>=3.9,
-    and isn't easily back-ported.
-    """
-
-    @property
-    def iloc(self) -> _iLocIndexer: ...
-
-
-OptionSequence: TypeAlias = Union[
-    Iterable[V_co],
-    DataFrameGenericAlias[V_co],
-]
-
-
-Key: TypeAlias = Union[str, int]
-
-LabelVisibility = Literal["visible", "hidden", "collapsed"]
-
-VegaLiteType = Literal["quantitative", "ordinal", "temporal", "nominal"]
-
-ChartStackType = Literal["normalize", "center", "layered"]
-
-
 class SupportsStr(Protocol):
     def __str__(self) -> str: ...
 
 
-def is_array_value_field_name(obj: object) -> TypeGuard[ArrayValueFieldName]:
-    return obj in ARRAY_VALUE_FIELD_NAMES
-
-
 @overload
 def is_type(
     obj: object, fqn_type_pattern: Literal["pydeck.bindings.deck.Deck"]
@@ -206,54 +104,6 @@ def get_fqn_type(obj: object) -> str:
     return get_fqn(type(obj))
 
 
-_PANDAS_DF_TYPE_STR: Final = "pandas.core.frame.DataFrame"
-_PANDAS_INDEX_TYPE_STR: Final = "pandas.core.indexes.base.Index"
-_PANDAS_SERIES_TYPE_STR: Final = "pandas.core.series.Series"
-_PANDAS_STYLER_TYPE_STR: Final = "pandas.io.formats.style.Styler"
-_NUMPY_ARRAY_TYPE_STR: Final = "numpy.ndarray"
-_SNOWPARK_DF_TYPE_STR: Final = "snowflake.snowpark.dataframe.DataFrame"
-_SNOWPARK_DF_ROW_TYPE_STR: Final = "snowflake.snowpark.row.Row"
-_SNOWPARK_TABLE_TYPE_STR: Final = "snowflake.snowpark.table.Table"
-_PYSPARK_DF_TYPE_STR: Final = "pyspark.sql.dataframe.DataFrame"
-_MODIN_DF_TYPE_STR: Final = "modin.pandas.dataframe.DataFrame"
-_MODIN_SERIES_TYPE_STR: Final = "modin.pandas.series.Series"
-_SNOWPANDAS_DF_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.dataframe.DataFrame"
-_SNOWPANDAS_SERIES_TYPE_STR: Final = "snowflake.snowpark.modin.pandas.series.Series"
-
-
-_DATAFRAME_LIKE_TYPES: Final[tuple[str, ...]] = (
-    _PANDAS_DF_TYPE_STR,
-    _PANDAS_INDEX_TYPE_STR,
-    _PANDAS_SERIES_TYPE_STR,
-    _PANDAS_STYLER_TYPE_STR,
-    _NUMPY_ARRAY_TYPE_STR,
-)
-
-# We show a special "UnevaluatedDataFrame" warning for cached funcs
-# that attempt to return one of these unserializable types:
-UNEVALUATED_DATAFRAME_TYPES = (
-    _MODIN_DF_TYPE_STR,
-    _MODIN_SERIES_TYPE_STR,
-    _PYSPARK_DF_TYPE_STR,
-    _SNOWPANDAS_DF_TYPE_STR,
-    _SNOWPANDAS_SERIES_TYPE_STR,
-    _SNOWPARK_DF_TYPE_STR,
-    _SNOWPARK_TABLE_TYPE_STR,
-)
-
-DataFrameLike: TypeAlias = "Union[DataFrame, Index, Series, Styler]"
-
-_DATAFRAME_COMPATIBLE_TYPES: Final[tuple[type, ...]] = (
-    dict,
-    list,
-    set,
-    tuple,
-    type(None),
-)
-
-_DataFrameCompatible: TypeAlias = Union[dict, list, set, Tuple[Any], None]
-DataFrameCompatible: TypeAlias = Union[_DataFrameCompatible, DataFrameLike]
-
 _BYTES_LIKE_TYPES: Final[tuple[type, ...]] = (
     bytes,
     bytearray,
@@ -262,104 +112,6 @@ _BYTES_LIKE_TYPES: Final[tuple[type, ...]] = (
 BytesLike: TypeAlias = Union[bytes, bytearray]
 
 
-class DataFormat(Enum):
-    """DataFormat is used to determine the format of the data."""
-
-    UNKNOWN = auto()
-    EMPTY = auto()  # None
-    PANDAS_DATAFRAME = auto()  # pd.DataFrame
-    PANDAS_SERIES = auto()  # pd.Series
-    PANDAS_INDEX = auto()  # pd.Index
-    NUMPY_LIST = auto()  # np.array[Scalar]
-    NUMPY_MATRIX = auto()  # np.array[List[Scalar]]
-    PYARROW_TABLE = auto()  # pyarrow.Table
-    SNOWPARK_OBJECT = auto()  # Snowpark DataFrame, Table, List[Row]
-    PYSPARK_OBJECT = auto()  # pyspark.DataFrame
-    MODIN_OBJECT = auto()  # Modin DataFrame, Series
-    SNOWPANDAS_OBJECT = auto()  # Snowpandas DataFrame, Series
-    PANDAS_STYLER = auto()  # pandas Styler
-    LIST_OF_RECORDS = auto()  # List[Dict[str, Scalar]]
-    LIST_OF_ROWS = auto()  # List[List[Scalar]]
-    LIST_OF_VALUES = auto()  # List[Scalar]
-    TUPLE_OF_VALUES = auto()  # Tuple[Scalar]
-    SET_OF_VALUES = auto()  # Set[Scalar]
-    COLUMN_INDEX_MAPPING = auto()  # {column: {index: value}}
-    COLUMN_VALUE_MAPPING = auto()  # {column: List[values]}
-    COLUMN_SERIES_MAPPING = auto()  # {column: Series(values)}
-    KEY_VALUE_DICT = auto()  # {index: value}
-
-
-def is_dataframe(obj: object) -> TypeGuard[DataFrame]:
-    return is_type(obj, _PANDAS_DF_TYPE_STR)
-
-
-def is_dataframe_like(obj: object) -> TypeGuard[DataFrameLike]:
-    return any(is_type(obj, t) for t in _DATAFRAME_LIKE_TYPES)
-
-
-def is_unevaluated_data_object(obj: object) -> bool:
-    """True if the object is one of the supported unevaluated data objects:
-
-    Currently supported objects are:
-    - Snowpark DataFrame / Table
-    - PySpark DataFrame
-    - Modin DataFrame / Series
-    - Snowpandas DataFrame / Series
-
-    Unevaluated means that the data is not yet in the local memory.
-    Unevaluated data objects are treated differently from other data objects by only
-    requesting a subset of the data instead of loading all data into th memory
-    """
-    return (
-        is_snowpark_data_object(obj)
-        or is_pyspark_data_object(obj)
-        or is_snowpandas_data_object(obj)
-        or is_modin_data_object(obj)
-    )
-
-
-def is_snowpark_data_object(obj: object) -> bool:
-    """True if obj is a Snowpark DataFrame or Table."""
-    return is_type(obj, _SNOWPARK_TABLE_TYPE_STR) or is_type(obj, _SNOWPARK_DF_TYPE_STR)
-
-
-def is_snowpark_row_list(obj: object) -> bool:
-    """True if obj is a list of snowflake.snowpark.row.Row."""
-    if not isinstance(obj, list):
-        return False
-    if len(obj) < 1:
-        return False
-    if not hasattr(obj[0], "__class__"):
-        return False
-    return is_type(obj[0], _SNOWPARK_DF_ROW_TYPE_STR)
-
-
-def is_pyspark_data_object(obj: object) -> bool:
-    """True if obj is of type pyspark.sql.dataframe.DataFrame"""
-    return (
-        is_type(obj, _PYSPARK_DF_TYPE_STR)
-        and hasattr(obj, "toPandas")
-        and callable(obj.toPandas)
-    )
-
-
-def is_modin_data_object(obj: object) -> bool:
-    """True if obj is of Modin Dataframe or Series"""
-    return is_type(obj, _MODIN_DF_TYPE_STR) or is_type(obj, _MODIN_SERIES_TYPE_STR)
-
-
-def is_snowpandas_data_object(obj: object) -> bool:
-    """True if obj is a Snowpark Pandas DataFrame or Series."""
-    return is_type(obj, _SNOWPANDAS_DF_TYPE_STR) or is_type(
-        obj, _SNOWPANDAS_SERIES_TYPE_STR
-    )
-
-
-def is_dataframe_compatible(obj: object) -> TypeGuard[DataFrameCompatible]:
-    """True if type that can be passed to convert_anything_to_df."""
-    return is_dataframe_like(obj) or type(obj) in _DATAFRAME_COMPATIBLE_TYPES
-
-
 def is_bytes_like(obj: object) -> TypeGuard[BytesLike]:
     """True if the type is considered bytes-like for the purposes of
     protobuf data marshalling.
@@ -432,15 +184,6 @@ def is_openai_chunk(obj: object) -> bool:
     return is_type(obj, _OPENAI_CHUNK_RE)
 
 
-def is_list_of_scalars(data: Iterable[Any]) -> bool:
-    """Check if the list only contains scalar values."""
-    from pandas.api.types import infer_dtype
-
-    # Overview on all value that are interpreted as scalar:
-    # https://pandas.pydata.org/docs/reference/api/pandas.api.types.is_scalar.html
-    return infer_dtype(data, skipna=True) not in ["mixed", "unknown-array"]
-
-
 def is_plotly_chart(obj: object) -> TypeGuard[Figure | list[Any] | dict[str, Any]]:
     """True if input looks like a Plotly chart."""
     return (
@@ -514,10 +257,6 @@ def is_namedtuple(x: object) -> TypeGuard[NamedTuple]:
     return all(type(n).__name__ == "str" for n in f)
 
 
-def is_pandas_styler(obj: object) -> TypeGuard[Styler]:
-    return is_type(obj, _PANDAS_STYLER_TYPE_STR)
-
-
 def is_pydeck(obj: object) -> TypeGuard[Deck]:
     """True if input looks like a pydeck chart."""
     return is_type(obj, "pydeck.bindings.deck.Deck")
@@ -549,208 +288,6 @@ def is_sequence(seq: Any) -> bool:
     return True
 
 
-@overload
-def convert_anything_to_df(
-    data: Any,
-    max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
-    ensure_copy: bool = False,
-) -> DataFrame: ...
-
-
-@overload
-def convert_anything_to_df(
-    data: Any,
-    max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
-    ensure_copy: bool = False,
-    allow_styler: bool = False,
-) -> DataFrame | Styler: ...
-
-
-def convert_anything_to_df(
-    data: Any,
-    max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
-    ensure_copy: bool = False,
-    allow_styler: bool = False,
-) -> DataFrame | Styler:
-    """Try to convert different formats to a Pandas Dataframe.
-
-    Parameters
-    ----------
-    data : ndarray, Iterable, dict, DataFrame, Styler, pa.Table, None, dict, list, or any
-
-    max_unevaluated_rows: int
-        If unevaluated data is detected this func will evaluate it,
-        taking max_unevaluated_rows, defaults to 10k and 100 for st.table
-
-    ensure_copy: bool
-        If True, make sure to always return a copy of the data. If False, it depends on the
-        type of the data. For example, a Pandas DataFrame will be returned as-is.
-
-    allow_styler: bool
-        If True, allows this to return a Pandas Styler object as well. If False, returns
-        a plain Pandas DataFrame (which, of course, won't contain the Styler's styles).
-
-    Returns
-    -------
-    pandas.DataFrame or pandas.Styler
-
-    """
-    import pandas as pd
-
-    if is_type(data, _PANDAS_DF_TYPE_STR):
-        return data.copy() if ensure_copy else cast(pd.DataFrame, data)
-
-    if is_pandas_styler(data):
-        # Every Styler is a StyleRenderer. I'm casting to StyleRenderer here rather than to the more
-        # correct Styler becayse MyPy doesn't like when we cast to Styler. It complains .data
-        # doesn't exist, when it does in fact exist in the parent class StyleRenderer!
-        sr = cast("StyleRenderer", data)
-
-        if allow_styler:
-            if ensure_copy:
-                out = copy.deepcopy(sr)
-                out.data = sr.data.copy()
-                return cast("Styler", out)
-            else:
-                return data
-        else:
-            return cast("Styler", sr.data.copy() if ensure_copy else sr.data)
-
-    if is_type(data, "numpy.ndarray"):
-        if len(data.shape) == 0:
-            return pd.DataFrame([])
-        return pd.DataFrame(data)
-
-    if is_modin_data_object(data):
-        data = data.head(max_unevaluated_rows)._to_pandas()
-
-        if isinstance(data, pd.Series):
-            data = data.to_frame()
-
-        if data.shape[0] == max_unevaluated_rows:
-            st.caption(
-                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
-                "Call `_to_pandas()` on the dataframe to show more."
-            )
-        return cast(pd.DataFrame, data)
-
-    if is_pyspark_data_object(data):
-        data = data.limit(max_unevaluated_rows).toPandas()
-        if data.shape[0] == max_unevaluated_rows:
-            st.caption(
-                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
-                "Call `toPandas()` on the dataframe to show more."
-            )
-        return cast(pd.DataFrame, data)
-
-    if is_snowpark_data_object(data):
-        data = data.limit(max_unevaluated_rows).to_pandas()
-        if data.shape[0] == max_unevaluated_rows:
-            st.caption(
-                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
-                "Call `to_pandas()` on the dataframe to show more."
-            )
-        return cast(pd.DataFrame, data)
-
-    if is_snowpandas_data_object(data):
-        data = data.head(max_unevaluated_rows).to_pandas()
-
-        if isinstance(data, pd.Series):
-            data = data.to_frame()
-
-        if data.shape[0] == max_unevaluated_rows:
-            st.caption(
-                f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
-                "Call `to_pandas()` on the dataframe to show more."
-            )
-        return cast(pd.DataFrame, data)
-
-    # This is inefficient when data is a pyarrow.Table as it will be converted
-    # back to Arrow when marshalled to protobuf, but area/bar/line charts need
-    # DataFrame magic to generate the correct output.
-    if hasattr(data, "to_pandas"):
-        return cast(pd.DataFrame, data.to_pandas())
-
-    # Try to convert to pandas.DataFrame. This will raise an error is df is not
-    # compatible with the pandas.DataFrame constructor.
-    try:
-        return pd.DataFrame(data)
-
-    except ValueError as ex:
-        if isinstance(data, dict):
-            with contextlib.suppress(ValueError):
-                # Try to use index orient as back-up to support key-value dicts
-                return pd.DataFrame.from_dict(data, orient="index")
-        raise errors.StreamlitAPIException(
-            f"""
-Unable to convert object of type `{type(data)}` to `pandas.DataFrame`.
-Offending object:
-```py
-{data}
-```"""
-        ) from ex
-
-
-@overload
-def ensure_iterable(obj: Iterable[V_co]) -> Iterable[V_co]: ...
-
-
-@overload
-def ensure_iterable(obj: OptionSequence[V_co]) -> Iterable[Any]: ...
-
-
-def ensure_iterable(obj: OptionSequence[V_co] | Iterable[V_co]) -> Iterable[Any]:
-    """Try to convert different formats to something iterable. Most inputs
-    are assumed to be iterable, but if we have a DataFrame, we can just
-    select the first column to iterate over. If the input is not iterable,
-    a TypeError is raised.
-
-    Parameters
-    ----------
-    obj : list, tuple, numpy.ndarray, pandas.Series, pandas.DataFrame, pyspark.sql.DataFrame, snowflake.snowpark.dataframe.DataFrame or snowflake.snowpark.table.Table
-
-    Returns
-    -------
-    iterable
-
-    """
-
-    if is_unevaluated_data_object(obj):
-        obj = convert_anything_to_df(obj)
-
-    if is_dataframe(obj):
-        # Return first column as a pd.Series
-        # The type of the elements in this column is not known up front, hence
-        # the Iterable[Any] return type.
-        return cast(Iterable[Any], obj.iloc[:, 0])
-
-    if is_iterable(obj):
-        return obj
-
-    raise TypeError(
-        f"Object is not an iterable and could not be converted to one. Object: {obj}"
-    )
-
-
-def ensure_indexable(obj: OptionSequence[V_co]) -> Sequence[V_co]:
-    """Try to ensure a value is an indexable Sequence. If the collection already
-    is one, it has the index method that we need. Otherwise, convert it to a list.
-    """
-    it = ensure_iterable(obj)
-    # This is an imperfect check because there is no guarantee that an `index`
-    # function actually does the thing we want.
-    index_fn = getattr(it, "index", None)
-    if callable(index_fn) and type(it) != EnumMeta:
-        # We return a shallow copy of the Sequence here because the return value of
-        # this function is saved in a widget serde class instance to be used in later
-        # script runs, and we don't want mutations to the options object passed to a
-        # widget affect the widget.
-        # (See https://github.com/streamlit/streamlit/issues/7534)
-        return copy.copy(cast(Sequence[V_co], it))
-    else:
-        return list(it)
-
-
 def check_python_comparable(seq: Sequence[Any]) -> None:
     """Check if the sequence elements support "python comparison".
     That means that the equality operator (==) returns a boolean value.
@@ -852,609 +389,3 @@ def is_version_less_than(v1: str, v2: str) -> bool:
     from packaging import version
 
     return version.parse(v1) < version.parse(v2)
-
-
-def _maybe_truncate_table(
-    table: pa.Table, truncated_rows: int | None = None
-) -> pa.Table:
-    """Experimental feature to automatically truncate tables that
-    are larger than the maximum allowed message size. It needs to be enabled
-    via the server.enableArrowTruncation config option.
-
-    Parameters
-    ----------
-    table : pyarrow.Table
-        A table to truncate.
-
-    truncated_rows : int or None
-        The number of rows that have been truncated so far. This is used by
-        the recursion logic to keep track of the total number of truncated
-        rows.
-
-    """
-
-    if config.get_option("server.enableArrowTruncation"):
-        # This is an optimization problem: We don't know at what row
-        # the perfect cut-off is to comply with the max size. But we want to figure
-        # it out in as few iterations as possible. We almost always will cut out
-        # more than required to keep the iterations low.
-
-        # The maximum size allowed for protobuf messages in bytes:
-        max_message_size = int(config.get_option("server.maxMessageSize") * 1e6)
-        # We add 1 MB for other overhead related to the protobuf message.
-        # This is a very conservative estimate, but it should be good enough.
-        table_size = int(table.nbytes + 1 * 1e6)
-        table_rows = table.num_rows
-
-        if table_rows > 1 and table_size > max_message_size:
-            # targeted rows == the number of rows the table should be truncated to.
-            # Calculate an approximation of how many rows we need to truncate to.
-            targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
-            # Make sure to cut out at least a couple of rows to avoid running
-            # this logic too often since it is quite inefficient and could lead
-            # to infinity recursions without these precautions.
-            targeted_rows = math.floor(
-                max(
-                    min(
-                        # Cut out:
-                        # an additional 5% of the estimated num rows to cut out:
-                        targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
-                        # at least 1% of table size:
-                        table_rows - (table_rows * 0.01),
-                        # at least 5 rows:
-                        table_rows - 5,
-                    ),
-                    1,  # but it should always have at least 1 row
-                )
-            )
-            sliced_table = table.slice(0, targeted_rows)
-            return _maybe_truncate_table(
-                sliced_table, (truncated_rows or 0) + (table_rows - targeted_rows)
-            )
-
-        if truncated_rows:
-            displayed_rows = string_util.simplify_number(table.num_rows)
-            total_rows = string_util.simplify_number(table.num_rows + truncated_rows)
-
-            if displayed_rows == total_rows:
-                # If the simplified numbers are the same,
-                # we just display the exact numbers.
-                displayed_rows = str(table.num_rows)
-                total_rows = str(table.num_rows + truncated_rows)
-
-            st.caption(
-                f"⚠️ Showing {displayed_rows} out of {total_rows} "
-                "rows due to data size limitations."
-            )
-
-    return table
-
-
-def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
-    """Serialize pyarrow.Table to bytes using Apache Arrow.
-
-    Parameters
-    ----------
-    table : pyarrow.Table
-        A table to convert.
-
-    """
-    try:
-        table = _maybe_truncate_table(table)
-    except RecursionError as err:
-        # This is a very unlikely edge case, but we want to make sure that
-        # it doesn't lead to unexpected behavior.
-        # If there is a recursion error, we just return the table as-is
-        # which will lead to the normal message limit exceed error.
-        _LOGGER.warning(
-            "Recursion error while truncating Arrow table. This is not "
-            "supposed to happen.",
-            exc_info=err,
-        )
-
-    import pyarrow as pa
-
-    # Convert table to bytes
-    sink = pa.BufferOutputStream()
-    writer = pa.RecordBatchStreamWriter(sink, table.schema)
-    writer.write_table(table)
-    writer.close()
-    return cast(bytes, sink.getvalue().to_pybytes())
-
-
-def is_colum_type_arrow_incompatible(column: Series[Any] | Index) -> bool:
-    """Return True if the column type is known to cause issues during Arrow conversion."""
-    from pandas.api.types import infer_dtype, is_dict_like, is_list_like
-
-    if column.dtype.kind in [
-        "c",  # complex64, complex128, complex256
-    ]:
-        return True
-
-    if str(column.dtype) in {
-        # These period types are not yet supported by our frontend impl.
-        # See comments in Quiver.ts for more details.
-        "period[B]",
-        "period[N]",
-        "period[ns]",
-        "period[U]",
-        "period[us]",
-    }:
-        return True
-
-    if column.dtype == "object":
-        # The dtype of mixed type columns is always object, the actual type of the column
-        # values can be determined via the infer_dtype function:
-        # https://pandas.pydata.org/docs/reference/api/pandas.api.types.infer_dtype.html
-        inferred_type = infer_dtype(column, skipna=True)
-
-        if inferred_type in [
-            "mixed-integer",
-            "complex",
-        ]:
-            return True
-        elif inferred_type == "mixed":
-            # This includes most of the more complex/custom types (objects, dicts, lists, ...)
-            if len(column) == 0 or not hasattr(column, "iloc"):
-                # The column seems to be invalid, so we assume it is incompatible.
-                # But this would most likely never happen since empty columns
-                # cannot be mixed.
-                return True
-
-            # Get the first value to check if it is a supported list-like type.
-            first_value = column.iloc[0]
-
-            if (
-                not is_list_like(first_value)
-                # dicts are list-like, but have issues in Arrow JS (see comments in Quiver.ts)
-                or is_dict_like(first_value)
-                # Frozensets are list-like, but are not compatible with pyarrow.
-                or isinstance(first_value, frozenset)
-            ):
-                # This seems to be an incompatible list-like type
-                return True
-            return False
-    # We did not detect an incompatible type, so we assume it is compatible:
-    return False
-
-
-def fix_arrow_incompatible_column_types(
-    df: DataFrame, selected_columns: list[str] | None = None
-) -> DataFrame:
-    """Fix column types that are not supported by Arrow table.
-
-    This includes mixed types (e.g. mix of integers and strings)
-    as well as complex numbers (complex128 type). These types will cause
-    errors during conversion of the dataframe to an Arrow table.
-    It is fixed by converting all values of the column to strings
-    This is sufficient for displaying the data on the frontend.
-
-    Parameters
-    ----------
-    df : pandas.DataFrame
-        A dataframe to fix.
-
-    selected_columns: List[str] or None
-        A list of columns to fix. If None, all columns are evaluated.
-
-    Returns
-    -------
-    The fixed dataframe.
-    """
-    import pandas as pd
-
-    # Make a copy, but only initialize if necessary to preserve memory.
-    df_copy: DataFrame | None = None
-    for col in selected_columns or df.columns:
-        if is_colum_type_arrow_incompatible(df[col]):
-            if df_copy is None:
-                df_copy = df.copy()
-            df_copy[col] = df[col].astype("string")
-
-    # The index can also contain mixed types
-    # causing Arrow issues during conversion.
-    # Skipping multi-indices since they won't return
-    # the correct value from infer_dtype
-    if not selected_columns and (
-        not isinstance(
-            df.index,
-            pd.MultiIndex,
-        )
-        and is_colum_type_arrow_incompatible(df.index)
-    ):
-        if df_copy is None:
-            df_copy = df.copy()
-        df_copy.index = df.index.astype("string")
-    return df_copy if df_copy is not None else df
-
-
-def data_frame_to_bytes(df: DataFrame) -> bytes:
-    """Serialize pandas.DataFrame to bytes using Apache Arrow.
-
-    Parameters
-    ----------
-    df : pandas.DataFrame
-        A dataframe to convert.
-
-    """
-    import pyarrow as pa
-
-    try:
-        table = pa.Table.from_pandas(df)
-    except (pa.ArrowTypeError, pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
-        _LOGGER.info(
-            "Serialization of dataframe to Arrow table was unsuccessful due to: %s. "
-            "Applying automatic fixes for column types to make the dataframe Arrow-compatible.",
-            ex,
-        )
-        df = fix_arrow_incompatible_column_types(df)
-        table = pa.Table.from_pandas(df)
-    return pyarrow_table_to_bytes(table)
-
-
-def bytes_to_data_frame(source: bytes) -> DataFrame:
-    """Convert bytes to pandas.DataFrame.
-
-    Using this function in production needs to make sure that
-    the pyarrow version >= 14.0.1.
-
-    Parameters
-    ----------
-    source : bytes
-        A bytes object to convert.
-
-    """
-    import pyarrow as pa
-
-    reader = pa.RecordBatchStreamReader(source)
-    return reader.read_pandas()
-
-
-def determine_data_format(input_data: Any) -> DataFormat:
-    """Determine the data format of the input data.
-
-    Parameters
-    ----------
-    input_data : Any
-        The input data to determine the data format of.
-
-    Returns
-    -------
-    DataFormat
-        The data format of the input data.
-    """
-    import numpy as np
-    import pandas as pd
-    import pyarrow as pa
-
-    if input_data is None:
-        return DataFormat.EMPTY
-    elif isinstance(input_data, pd.DataFrame):
-        return DataFormat.PANDAS_DATAFRAME
-    elif isinstance(input_data, np.ndarray):
-        if len(input_data.shape) == 1:
-            # For technical reasons, we need to distinguish one
-            # one-dimensional numpy array from multidimensional ones.
-            return DataFormat.NUMPY_LIST
-        return DataFormat.NUMPY_MATRIX
-    elif isinstance(input_data, pa.Table):
-        return DataFormat.PYARROW_TABLE
-    elif isinstance(input_data, pd.Series):
-        return DataFormat.PANDAS_SERIES
-    elif isinstance(input_data, pd.Index):
-        return DataFormat.PANDAS_INDEX
-    elif is_pandas_styler(input_data):
-        return DataFormat.PANDAS_STYLER
-    elif is_snowpark_data_object(input_data):
-        return DataFormat.SNOWPARK_OBJECT
-    elif is_modin_data_object(input_data):
-        return DataFormat.MODIN_OBJECT
-    elif is_snowpandas_data_object(input_data):
-        return DataFormat.SNOWPANDAS_OBJECT
-    elif is_pyspark_data_object(input_data):
-        return DataFormat.PYSPARK_OBJECT
-    elif isinstance(input_data, (list, tuple, set)):
-        if is_list_of_scalars(input_data):
-            # -> one-dimensional data structure
-            if isinstance(input_data, tuple):
-                return DataFormat.TUPLE_OF_VALUES
-            if isinstance(input_data, set):
-                return DataFormat.SET_OF_VALUES
-            return DataFormat.LIST_OF_VALUES
-        else:
-            # -> Multi-dimensional data structure
-            # This should always contain at least one element,
-            # otherwise the values type from infer_dtype would have been empty
-            first_element = next(iter(input_data))
-            if isinstance(first_element, dict):
-                return DataFormat.LIST_OF_RECORDS
-            if isinstance(first_element, (list, tuple, set)):
-                return DataFormat.LIST_OF_ROWS
-    elif isinstance(input_data, dict):
-        if not input_data:
-            return DataFormat.KEY_VALUE_DICT
-        if len(input_data) > 0:
-            first_value = next(iter(input_data.values()))
-            if isinstance(first_value, dict):
-                return DataFormat.COLUMN_INDEX_MAPPING
-            if isinstance(first_value, (list, tuple)):
-                return DataFormat.COLUMN_VALUE_MAPPING
-            if isinstance(first_value, pd.Series):
-                return DataFormat.COLUMN_SERIES_MAPPING
-            # In the future, we could potentially also support the tight & split formats here
-            if is_list_of_scalars(input_data.values()):
-                # Only use the key-value dict format if the values are only scalar values
-                return DataFormat.KEY_VALUE_DICT
-    return DataFormat.UNKNOWN
-
-
-def _unify_missing_values(df: DataFrame) -> DataFrame:
-    """Unify all missing values in a DataFrame to None.
-
-    Pandas uses a variety of values to represent missing values, including np.nan,
-    NaT, None, and pd.NA. This function replaces all of these values with None,
-    which is the only missing value type that is supported by all data
-    """
-    import numpy as np
-
-    return df.fillna(np.nan).replace([np.nan], [None])
-
-
-def convert_df_to_data_format(
-    df: DataFrame, data_format: DataFormat
-) -> (
-    DataFrame
-    | Series[Any]
-    | pa.Table
-    | np.ndarray[Any, np.dtype[Any]]
-    | tuple[Any]
-    | list[Any]
-    | set[Any]
-    | dict[str, Any]
-):
-    """Convert a dataframe to the specified data format.
-
-    Parameters
-    ----------
-    df : pd.DataFrame
-        The dataframe to convert.
-
-    data_format : DataFormat
-        The data format to convert to.
-
-    Returns
-    -------
-    pd.DataFrame, pd.Series, pyarrow.Table, np.ndarray, list, set, tuple, or dict.
-        The converted dataframe.
-    """
-
-    if data_format in [
-        DataFormat.EMPTY,
-        DataFormat.PANDAS_DATAFRAME,
-        DataFormat.SNOWPARK_OBJECT,
-        DataFormat.PYSPARK_OBJECT,
-        DataFormat.PANDAS_INDEX,
-        DataFormat.PANDAS_STYLER,
-        DataFormat.MODIN_OBJECT,
-        DataFormat.SNOWPANDAS_OBJECT,
-    ]:
-        return df
-    elif data_format == DataFormat.NUMPY_LIST:
-        import numpy as np
-
-        # It's a 1-dimensional array, so we only return
-        # the first column as numpy array
-        # Calling to_numpy() on the full DataFrame would result in:
-        # [[1], [2]] instead of [1, 2]
-        return np.ndarray(0) if df.empty else df.iloc[:, 0].to_numpy()
-    elif data_format == DataFormat.NUMPY_MATRIX:
-        import numpy as np
-
-        return np.ndarray(0) if df.empty else df.to_numpy()
-    elif data_format == DataFormat.PYARROW_TABLE:
-        import pyarrow as pa
-
-        return pa.Table.from_pandas(df)
-    elif data_format == DataFormat.PANDAS_SERIES:
-        # Select first column in dataframe and create a new series based on the values
-        if len(df.columns) != 1:
-            raise ValueError(
-                f"DataFrame is expected to have a single column but has {len(df.columns)}."
-            )
-        return df[df.columns[0]]
-    elif data_format == DataFormat.LIST_OF_RECORDS:
-        return _unify_missing_values(df).to_dict(orient="records")
-    elif data_format == DataFormat.LIST_OF_ROWS:
-        # to_numpy converts the dataframe to a list of rows
-        return _unify_missing_values(df).to_numpy().tolist()
-    elif data_format == DataFormat.COLUMN_INDEX_MAPPING:
-        return _unify_missing_values(df).to_dict(orient="dict")
-    elif data_format == DataFormat.COLUMN_VALUE_MAPPING:
-        return _unify_missing_values(df).to_dict(orient="list")
-    elif data_format == DataFormat.COLUMN_SERIES_MAPPING:
-        return df.to_dict(orient="series")
-    elif data_format in [
-        DataFormat.LIST_OF_VALUES,
-        DataFormat.TUPLE_OF_VALUES,
-        DataFormat.SET_OF_VALUES,
-    ]:
-        df = _unify_missing_values(df)
-        return_list = []
-        if len(df.columns) == 1:
-            # Get the first column and convert to list
-            return_list = df[df.columns[0]].tolist()
-        elif len(df.columns) >= 1:
-            raise ValueError(
-                f"DataFrame is expected to have a single column but has {len(df.columns)}."
-            )
-        if data_format == DataFormat.TUPLE_OF_VALUES:
-            return tuple(return_list)
-        if data_format == DataFormat.SET_OF_VALUES:
-            return set(return_list)
-        return return_list
-    elif data_format == DataFormat.KEY_VALUE_DICT:
-        df = _unify_missing_values(df)
-        # The key is expected to be the index -> this will return the first column
-        # as a dict with index as key.
-        return {} if df.empty else df.iloc[:, 0].to_dict()
-
-    raise ValueError(f"Unsupported input data format: {data_format}")
-
-
-@overload
-def to_key(key: None) -> None: ...
-
-
-@overload
-def to_key(key: Key) -> str: ...
-
-
-def to_key(key: Key | None) -> str | None:
-    if key is None:
-        return None
-    else:
-        return str(key)
-
-
-def maybe_tuple_to_list(item: Any) -> Any:
-    """Convert a tuple to a list. Leave as is if it's not a tuple."""
-    return list(item) if isinstance(item, tuple) else item
-
-
-def maybe_raise_label_warnings(label: str | None, label_visibility: str | None):
-    if not label:
-        _LOGGER.warning(
-            "`label` got an empty value. This is discouraged for accessibility "
-            "reasons and may be disallowed in the future by raising an exception. "
-            "Please provide a non-empty label and hide it with label_visibility "
-            "if needed."
-        )
-    if label_visibility not in ("visible", "hidden", "collapsed"):
-        raise errors.StreamlitAPIException(
-            f"Unsupported label_visibility option '{label_visibility}'. "
-            f"Valid values are 'visible', 'hidden' or 'collapsed'."
-        )
-
-
-# The code below is copied from Altair, and slightly modified.
-# We copy this code here so we don't depend on private Altair functions.
-# Source: https://github.com/altair-viz/altair/blob/62ca5e37776f5cecb27e83c1fbd5d685a173095d/altair/utils/core.py#L193
-
-
-# STREAMLIT MOD: I changed the type for the data argument from "pd.Series" to Series,
-# and the return type to a Union including a (str, list) tuple, since the function does
-# return that in some situations.
-def infer_vegalite_type(
-    data: Series[Any],
-) -> VegaLiteType:
-    """
-    From an array-like input, infer the correct vega typecode
-    ('ordinal', 'nominal', 'quantitative', or 'temporal')
-
-    Parameters
-    ----------
-    data: Numpy array or Pandas Series
-    """
-    from pandas.api.types import infer_dtype
-
-    # STREAMLIT MOD: I'm using infer_dtype directly here, rather than using Altair's wrapper. Their
-    # wrapper is only there to support Pandas < 0.20, but Streamlit requires Pandas 1.3.
-    typ = infer_dtype(data)
-
-    if typ in [
-        "floating",
-        "mixed-integer-float",
-        "integer",
-        "mixed-integer",
-        "complex",
-    ]:
-        return "quantitative"
-
-    elif typ == "categorical" and data.cat.ordered:
-        # STREAMLIT MOD: The original code returns a tuple here:
-        # return ("ordinal", data.cat.categories.tolist())
-        # But returning the tuple here isn't compatible with our
-        # built-in chart implementation. And it also doesn't seem to be necessary.
-        # Altair already extracts the correct sort order somewhere else.
-        # More info about the issue here: https://github.com/streamlit/streamlit/issues/7776
-        return "ordinal"
-    elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
-        return "nominal"
-    elif typ in [
-        "datetime",
-        "datetime64",
-        "timedelta",
-        "timedelta64",
-        "date",
-        "time",
-        "period",
-    ]:
-        return "temporal"
-    else:
-        # STREAMLIT MOD: I commented this out since Streamlit doesn't have a warnings object.
-        # warnings.warn(
-        #     "I don't know how to infer vegalite type from '{}'. "
-        #     "Defaulting to nominal.".format(typ),
-        #     stacklevel=1,
-        # )
-        return "nominal"
-
-
-E1 = TypeVar("E1", bound=Enum)
-E2 = TypeVar("E2", bound=Enum)
-
-ALLOWED_ENUM_COERCION_CONFIG_SETTINGS = ("off", "nameOnly", "nameAndValue")
-
-
-def coerce_enum(from_enum_value: E1, to_enum_class: type[E2]) -> E1 | E2:
-    """Attempt to coerce an Enum value to another EnumMeta.
-
-    An Enum value of EnumMeta E1 is considered coercable to EnumType E2
-    if the EnumMeta __qualname__ match and the names of their members
-    match as well. (This is configurable in streamlist configs)
-    """
-    if not isinstance(from_enum_value, Enum):
-        raise ValueError(
-            f"Expected an Enum in the first argument. Got {type(from_enum_value)}"
-        )
-    if not isinstance(to_enum_class, EnumMeta):
-        raise ValueError(
-            f"Expected an EnumMeta/Type in the second argument. Got {type(to_enum_class)}"
-        )
-    if isinstance(from_enum_value, to_enum_class):
-        return from_enum_value  # Enum is already a member, no coersion necessary
-
-    coercion_type = config.get_option("runner.enumCoercion")
-    if coercion_type not in ALLOWED_ENUM_COERCION_CONFIG_SETTINGS:
-        raise errors.StreamlitAPIException(
-            "Invalid value for config option runner.enumCoercion. "
-            f"Expected one of {ALLOWED_ENUM_COERCION_CONFIG_SETTINGS}, "
-            f"but got '{coercion_type}'."
-        )
-    if coercion_type == "off":
-        return from_enum_value  # do not attempt to coerce
-
-    # We now know this is an Enum AND the user has configured coercion enabled.
-    # Check if we do NOT meet the required conditions and log a failure message
-    # if that is the case.
-    from_enum_class = from_enum_value.__class__
-    if (
-        from_enum_class.__qualname__ != to_enum_class.__qualname__
-        or (
-            coercion_type == "nameOnly"
-            and set(to_enum_class._member_names_) != set(from_enum_class._member_names_)
-        )
-        or (
-            coercion_type == "nameAndValue"
-            and set(to_enum_class._value2member_map_)
-            != set(from_enum_class._value2member_map_)
-        )
-    ):
-        _LOGGER.debug("Failed to coerce %s to class %s", from_enum_value, to_enum_class)
-        return from_enum_value  # do not attempt to coerce
-
-    # At this point we think the Enum is coercable, and we know
-    # E1 and E2 have the same member names. We convert from E1 to E2 using _name_
-    # (since user Enum subclasses can override the .name property in 3.11)
-    _LOGGER.debug("Coerced %s to class %s", from_enum_value, to_enum_class)
-    return to_enum_class[from_enum_value._name_]
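
Note on the removals above: this release adds a new streamlit/dataframe_util.py (+835 lines, contents not shown in this diff) while streamlit/type_util.py loses its dataframe-related helpers, which suggests these utilities moved there, possibly under different names. For downstream code that imported these private helpers from streamlit.type_util, a minimal, hedged sketch of a defensive lookup follows; the module and attribute names are assumptions, not confirmed by this diff, and both locations are private Streamlit internals that may change again.

import importlib

def find_private_streamlit_helper(attr: str = "convert_anything_to_df"):
    # Probe the assumed new module first, then fall back to the pre-1.36.1.dev20240703
    # location. Returns the callable if found under this name, else None.
    for module_name in ("streamlit.dataframe_util", "streamlit.type_util"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    return None  # helper not found under this name in either module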