chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. chalk/__init__.py +2 -1
  2. chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
  3. chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
  4. chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
  5. chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
  6. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
  7. chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
  8. chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
  9. chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
  10. chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
  11. chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
  12. chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
  13. chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
  14. chalk/_gen/chalk/dataframe/__init__.py +0 -0
  15. chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
  16. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
  17. chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
  18. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
  19. chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
  20. chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
  21. chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
  22. chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
  23. chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
  24. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
  25. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
  26. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
  27. chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
  28. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
  29. chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
  30. chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
  31. chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
  32. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
  33. chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
  34. chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
  35. chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
  36. chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
  37. chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
  38. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
  39. chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
  40. chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
  41. chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
  42. chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
  43. chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
  44. chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
  45. chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
  46. chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
  47. chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
  48. chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
  49. chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
  50. chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
  51. chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
  52. chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
  53. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
  54. chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
  55. chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
  56. chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
  57. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
  58. chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
  59. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
  60. chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
  61. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
  62. chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
  63. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
  64. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
  65. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
  66. chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
  67. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
  68. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
  69. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
  70. chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
  71. chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
  72. chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
  73. chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
  74. chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
  75. chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
  76. chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
  77. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
  78. chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
  79. chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
  80. chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
  81. chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
  82. chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
  83. chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
  84. chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
  85. chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
  86. chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
  87. chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
  88. chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
  89. chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
  90. chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
  91. chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
  92. chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
  93. chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
  94. chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
  95. chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
  96. chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
  97. chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
  98. chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
  99. chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
  100. chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
  101. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
  102. chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
  103. chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
  104. chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
  105. chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
  106. chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
  107. chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
  108. chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
  109. chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
  110. chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
  111. chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
  112. chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
  113. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
  114. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
  115. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
  116. chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
  117. chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
  118. chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
  119. chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
  120. chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
  121. chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
  122. chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
  123. chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
  124. chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
  125. chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
  126. chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
  127. chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
  128. chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
  129. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
  130. chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
  131. chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
  132. chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
  133. chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
  134. chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
  135. chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
  136. chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
  137. chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
  138. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
  139. chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
  140. chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
  141. chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
  142. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
  143. chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
  144. chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
  145. chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
  146. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
  147. chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
  148. chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
  149. chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
  150. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
  151. chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
  152. chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
  153. chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
  154. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
  155. chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
  156. chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
  157. chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
  158. chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
  159. chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
  160. chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
  161. chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
  162. chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
  163. chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
  164. chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
  165. chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
  166. chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
  167. chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
  168. chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
  169. chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
  170. chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
  171. chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
  172. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
  173. chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
  174. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
  175. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
  176. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
  177. chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
  178. chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
  179. chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
  180. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
  181. chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
  182. chalk/_lsp/error_builder.py +11 -0
  183. chalk/_monitoring/Chart.py +1 -3
  184. chalk/_version.py +1 -1
  185. chalk/cli.py +5 -10
  186. chalk/client/client.py +178 -64
  187. chalk/client/client_async.py +154 -0
  188. chalk/client/client_async_impl.py +22 -0
  189. chalk/client/client_grpc.py +738 -112
  190. chalk/client/client_impl.py +541 -136
  191. chalk/client/dataset.py +27 -6
  192. chalk/client/models.py +99 -2
  193. chalk/client/serialization/model_serialization.py +126 -10
  194. chalk/config/project_config.py +1 -1
  195. chalk/df/LazyFramePlaceholder.py +1154 -0
  196. chalk/df/ast_parser.py +2 -10
  197. chalk/features/_class_property.py +7 -0
  198. chalk/features/_embedding/embedding.py +1 -0
  199. chalk/features/_embedding/sentence_transformer.py +1 -1
  200. chalk/features/_encoding/converter.py +83 -2
  201. chalk/features/_encoding/pyarrow.py +20 -4
  202. chalk/features/_encoding/rich.py +1 -3
  203. chalk/features/_tensor.py +1 -2
  204. chalk/features/dataframe/_filters.py +14 -5
  205. chalk/features/dataframe/_impl.py +91 -36
  206. chalk/features/dataframe/_validation.py +11 -7
  207. chalk/features/feature_field.py +40 -30
  208. chalk/features/feature_set.py +1 -2
  209. chalk/features/feature_set_decorator.py +1 -0
  210. chalk/features/feature_wrapper.py +42 -3
  211. chalk/features/hooks.py +81 -12
  212. chalk/features/inference.py +65 -10
  213. chalk/features/resolver.py +338 -56
  214. chalk/features/tag.py +1 -3
  215. chalk/features/underscore_features.py +2 -1
  216. chalk/functions/__init__.py +456 -21
  217. chalk/functions/holidays.py +1 -3
  218. chalk/gitignore/gitignore_parser.py +5 -1
  219. chalk/importer.py +186 -74
  220. chalk/ml/__init__.py +6 -2
  221. chalk/ml/model_hooks.py +368 -51
  222. chalk/ml/model_reference.py +68 -10
  223. chalk/ml/model_version.py +34 -21
  224. chalk/ml/utils.py +143 -40
  225. chalk/operators/_utils.py +14 -3
  226. chalk/parsed/_proto/export.py +22 -0
  227. chalk/parsed/duplicate_input_gql.py +4 -0
  228. chalk/parsed/expressions.py +1 -3
  229. chalk/parsed/json_conversions.py +21 -14
  230. chalk/parsed/to_proto.py +16 -4
  231. chalk/parsed/user_types_to_json.py +31 -10
  232. chalk/parsed/validation_from_registries.py +182 -0
  233. chalk/queries/named_query.py +16 -6
  234. chalk/queries/scheduled_query.py +13 -1
  235. chalk/serialization/parsed_annotation.py +25 -12
  236. chalk/sql/__init__.py +221 -0
  237. chalk/sql/_internal/integrations/athena.py +6 -1
  238. chalk/sql/_internal/integrations/bigquery.py +22 -2
  239. chalk/sql/_internal/integrations/databricks.py +61 -18
  240. chalk/sql/_internal/integrations/mssql.py +281 -0
  241. chalk/sql/_internal/integrations/postgres.py +11 -3
  242. chalk/sql/_internal/integrations/redshift.py +4 -0
  243. chalk/sql/_internal/integrations/snowflake.py +11 -2
  244. chalk/sql/_internal/integrations/util.py +2 -1
  245. chalk/sql/_internal/sql_file_resolver.py +55 -10
  246. chalk/sql/_internal/sql_source.py +36 -2
  247. chalk/streams/__init__.py +1 -3
  248. chalk/streams/_kafka_source.py +5 -1
  249. chalk/streams/_windows.py +16 -4
  250. chalk/streams/types.py +1 -2
  251. chalk/utils/__init__.py +1 -3
  252. chalk/utils/_otel_version.py +13 -0
  253. chalk/utils/async_helpers.py +14 -5
  254. chalk/utils/df_utils.py +2 -2
  255. chalk/utils/duration.py +1 -3
  256. chalk/utils/job_log_display.py +538 -0
  257. chalk/utils/missing_dependency.py +5 -4
  258. chalk/utils/notebook.py +255 -2
  259. chalk/utils/pl_helpers.py +190 -37
  260. chalk/utils/pydanticutil/pydantic_compat.py +1 -2
  261. chalk/utils/storage_client.py +246 -0
  262. chalk/utils/threading.py +1 -3
  263. chalk/utils/tracing.py +194 -86
  264. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
  265. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
  266. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
  267. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
  268. {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/utils/notebook.py CHANGED
@@ -1,15 +1,21 @@
+import ast
 import enum
 import functools
 import inspect
 import sys
 from contextvars import ContextVar
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 from chalk.utils.environment_parsing import env_var_bool

 if TYPE_CHECKING:
     from chalk.sql._internal.sql_file_resolver import SQLStringResult

+try:
+    from IPython.core.interactiveshell import InteractiveShell  # type: ignore
+except ImportError:
+    InteractiveShell = Any  # type: ignore
+

 def print_user_error(message: str, exception: Optional[Exception] = None, suggested_action: Optional[str] = None):
     print(f"\033[91mERROR: {message}\033[0m", file=sys.stderr)
@@ -29,7 +35,7 @@ class IPythonEvents(enum.Enum):
     POST_RUN_CELL = "post_run_cell"


-def get_ipython_or_none() -> Optional[object]:
+def get_ipython_or_none() -> Optional[Any]:
     """
     Returns the global IPython shell object, if this code is running in an ipython environment.
     :return: An `IPython.core.interactiveshell.InteractiveShell`, or None if we're not running in a notebook/ipython repl
@@ -129,3 +135,250 @@ def register_resolver_from_cell_magic(sql_string_result: "SQLStringResult"):
         return

     NOTEBOOK_DEFINED_SQL_RESOLVERS[sql_string_result.path] = resolver_result
+
+
+def is_valid_python_code(code_string: str):
+    try:
+        compile(code_string, "<string>", "exec")
+        return True
+    except (SyntaxError, ValueError):
+        return False
+
+
+def _get_import_names(node: Union[ast.Import, ast.ImportFrom], cell_source: str, import_source: str) -> set[str]:
+    """Extract the names that an import statement brings into scope."""
+    import ast
+
+    imported_names = set()
+    if isinstance(node, ast.Import):
+        for alias in node.names:
+            name = alias.asname if alias.asname else alias.name
+            imported_names.add(name)
+    else:  # ast.ImportFrom
+        for alias in node.names:
+            if alias.name == "*":
+                # Can't track wildcard imports precisely, so include the import text itself
+                imported_names.add(import_source)
+            else:
+                name = alias.asname if alias.asname else alias.name
+                imported_names.add(name)
+    return imported_names
+
+
+def _parse_notebook_cells(cells: list[tuple[int, int, str]]):
+    """Parse notebook cells and extract definitions of functions, classes, globals, and imports."""
+    import ast
+
+    latest_function_def: dict[str, tuple[str, ast.AST]] = {}  # name -> (source, ast_node)
+    latest_global_assign: dict[str, str] = {}  # name -> source
+    latest_class_def: dict[str, tuple[str, ast.AST]] = {}  # name -> (source, ast_node)
+    all_imports: dict[str, tuple[list[str], ast.AST]] = {}  # import_text -> (names_imported, ast_node)
+
+    for _, _, cell_source in cells:
+        cell_source = cell_source.strip()
+        if not cell_source:
+            continue
+
+        try:
+            cell_tree = ast.parse(cell_source)
+        except SyntaxError:
+            continue
+
+        for node in cell_tree.body:
+            if isinstance(node, (ast.Import, ast.ImportFrom)):
+                import_source = ast.get_source_segment(cell_source, node)
+                if import_source is None:
+                    continue
+                imported_names = _get_import_names(node, cell_source, import_source)
+                all_imports[import_source] = (list(imported_names), node)
+
+            elif isinstance(node, ast.FunctionDef):
+                func_source = ast.get_source_segment(cell_source, node)
+                if func_source is not None:
+                    latest_function_def[node.name] = (func_source, node)
+
+            elif isinstance(node, ast.ClassDef):
+                class_source = ast.get_source_segment(cell_source, node)
+                if class_source is not None:
+                    latest_class_def[node.name] = (class_source, node)
+
+            elif isinstance(node, ast.Assign):
+                for target in node.targets:
+                    if isinstance(target, ast.Name):
+                        assign_source = ast.get_source_segment(cell_source, node)
+                        if assign_source is not None:
+                            latest_global_assign[target.id] = assign_source
+
+    return latest_function_def, latest_class_def, latest_global_assign, all_imports
+
+
+def _get_referenced_names(source_code: str) -> set[str]:
+    """Extract all names referenced in source code."""
+    import ast
+
+    try:
+        tree = ast.parse(source_code)
+    except SyntaxError:
+        return set()
+
+    names = set()
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Name):
+            names.add(node.id)
+        elif isinstance(node, ast.Attribute):
+            # For module.function, capture the base module
+            if isinstance(node.value, ast.Name):
+                names.add(node.value.id)
+    return names
+
+
+def _collect_dependencies(
+    fn_source: str,
+    fn_name: str,
+    latest_function_def: dict[str, tuple[str, ast.AST]],
+    latest_class_def: dict[str, tuple[str, ast.AST]],
+    latest_global_assign: dict[str, str],
+    builtin_names: set[str],
+):
+    """Recursively collect all dependencies needed by the function."""
+    # maps name -> source
+    needed_functions: dict[str, str] = {}
+    needed_classes: dict[str, str] = {}
+    needed_globals: dict[str, str] = {}
+    needed_names: set[str] = set()
+
+    to_process = [fn_source]
+    processed = set()
+
+    while to_process:
+        current_source = to_process.pop()
+        if current_source in processed:
+            continue
+        processed.add(current_source)
+
+        referenced = _get_referenced_names(current_source)
+        referenced = referenced - builtin_names - {fn_name}
+        needed_names.update(referenced)
+
+        for name in referenced:
+            # Check if it's a class we defined
+            if name in latest_class_def and name not in needed_classes:
+                class_source, _ = latest_class_def[name]
+                needed_classes[name] = class_source
+                to_process.append(class_source)
+
+            # Check if it's a function we defined
+            elif name in latest_function_def and name not in needed_functions:
+                func_source, _ = latest_function_def[name]
+                needed_functions[name] = func_source
+                to_process.append(func_source)
+
+        for name in referenced:
+            # Check if it's a global variable we defined
+            if name in latest_global_assign and name not in needed_globals:
+                assign_source = latest_global_assign[name]
+                needed_globals[name] = assign_source
+                to_process.append(assign_source)
+
+    return needed_functions, needed_classes, needed_globals, needed_names
+
+
+def _filter_imports(all_imports: dict[str, tuple[list[str], ast.AST]], needed_names: set[str]) -> list[str]:
+    """Filter imports to only include those that are actually used."""
+    needed_imports: list[str] = []
+    for import_text, (imported_names, _) in all_imports.items():
+        if any(name in needed_names or name == import_text for name in imported_names):
+            needed_imports.append(import_text)
+    return needed_imports
+
+
+def _build_script(
+    fn_source: str,
+    fn_name: str,
+    needed_imports: list[str],
+    needed_globals: dict[str, str],
+    needed_classes: dict[str, str],
+    needed_functions: dict[str, str],
+) -> str:
+    """Build the final script from collected components."""
+    script_parts: list[str] = []
+
+    if needed_imports:
+        script_parts.extend(needed_imports)
+        script_parts.append("")
+
+    if needed_globals:
+        script_parts.extend(needed_globals.values())
+        script_parts.append("")
+
+    if needed_classes:
+        script_parts.extend(needed_classes.values())
+        script_parts.append("")
+
+    if needed_functions:
+        script_parts.extend(needed_functions.values())
+        script_parts.append("")
+
+    script_parts.append(fn_source)
+
+    return "\n".join(script_parts)
+
+
+def parse_notebook_into_script(fn: Callable[[], None], takes_argument: bool) -> str:
+    """
+    Parse a notebook function and its dependencies into a standalone Python script.
+
+    The function must take no inputs and produce no outputs. The output script will
+    call fn() in __main__ and include all necessary imports, globals, and helper
+    functions that have been executed in the notebook.
+
+    Args:
+        fn (Callable[[], None]): A callable with no parameters and no return value.
+
+    Returns:
+        str: A Python script as a string.
+    """
+    import builtins
+
+    if not is_notebook():
+        raise RuntimeError("parse_notebook_into_script should only be called from a notebook environment.")
+
+    sig = inspect.signature(fn)
+    if len(sig.parameters) != int(takes_argument):
+        raise ValueError(
+            f"Function {fn.__name__} must take {int(takes_argument)} inputs, but has parameters: {list(sig.parameters.keys())}"
+        )
+
+    shell = get_ipython_or_none()
+    if shell is None:
+        raise RuntimeError("Could not access IPython shell")
+
+    # Get the cell contents of executed cells
+    if getattr(shell, "history_manager", None) is None:
+        raise RuntimeError("Could not access IPython history manager")
+
+    history_manager = shell.history_manager
+    session_number = history_manager.get_last_session_id()
+    cells = list(history_manager.get_range(session=session_number, start=1))
+
+    # Parse cells to extract definitions
+    latest_function_def, latest_class_def, latest_global_assign, all_imports = _parse_notebook_cells(cells)
+
+    # Get function source and collect dependencies
+    fn_source = inspect.getsource(fn)
+    builtin_names = set(dir(builtins))
+
+    needed_functions, needed_classes, needed_globals, needed_names = _collect_dependencies(
+        fn_source, fn.__name__, latest_function_def, latest_class_def, latest_global_assign, builtin_names
+    )
+
+    # Filter imports to only used ones
+    needed_imports = _filter_imports(all_imports, needed_names)
+
+    # Build and return the script
+    script = _build_script(fn_source, fn.__name__, needed_imports, needed_globals, needed_classes, needed_functions)
+
+    if not is_valid_python_code(script):
+        raise RuntimeError("Error generating valid training function from notebook")
+
+    return script
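The parse_notebook_into_script helper added above walks the current IPython session's cell history, gathers the imports, globals, classes, and helper functions that a target function references, and assembles them into a standalone Python script. A minimal usage sketch, assuming it is run from a cell in a live notebook with chalkpy 2.95.x installed; the train function and LEARNING_RATE global below are hypothetical, not part of the package:

    from chalk.utils.notebook import parse_notebook_into_script

    LEARNING_RATE = 0.01  # module-level assignment; pulled into the script because train() references it


    def train():
        # Zero-parameter function, matching takes_argument=False below
        print(f"training with learning rate {LEARNING_RATE}")


    # Returns the generated source as a string: the needed imports, globals,
    # class/function definitions from earlier cells, and train() itself.
    script = parse_notebook_into_script(train, takes_argument=False)
    print(script)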
chalk/utils/pl_helpers.py CHANGED
@@ -1,13 +1,12 @@
 from __future__ import annotations

 import itertools
+import zoneinfo
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Iterator, TypeVar
+from typing import TYPE_CHECKING, Any, Iterator, TypeGuard, TypeVar, overload

 import pyarrow as pa
-import zoneinfo
 from packaging.version import parse
-from typing_extensions import TypeGuard

 from chalk.utils.log_with_context import get_logger
 from chalk.utils.missing_dependency import missing_dependency_exception
@@ -27,6 +26,13 @@ except ImportError:
     json_loads = json.loads


+def json_loads_as_str(x: str | None):
+    if x is None:
+        return None
+    x = json_loads(x)
+    return x if x is None else str(x)
+
+
 def is_version_gte(version: str, target: str) -> bool:
     return parse(version) >= parse(target)

@@ -36,9 +42,46 @@ try:

     is_new_polars = is_version_gte(pl.__version__, "0.18.0")
     polars_has_pad_start = is_version_gte(pl.__version__, "0.19.12")
+    polars_array_uses_shape = is_version_gte(pl.__version__, "1.0.0")
+    polars_uses_schema_overrides = is_version_gte(pl.__version__, "0.20.31")
+    polars_join_ignores_nulls = is_version_gte(pl.__version__, "0.20.0")
+    polars_broken_concat_on_nested_list = is_version_gte(pl.__version__, "1.0.0")
+    polars_group_by_instead_of_groupby = is_version_gte(pl.__version__, "1.0.0")
+    polars_name_dot_suffix_instead_of_suffix = is_version_gte(pl.__version__, "1.0.0")
+    polars_lazy_frame_collect_schema = is_version_gte(pl.__version__, "1.0.0")
+    polars_allow_lit_empty_struct = is_version_gte(pl.__version__, "1.0.0")
 except ImportError:
     is_new_polars = False
     polars_has_pad_start = False
+    polars_array_uses_shape = False
+    polars_uses_schema_overrides = False
+    polars_join_ignores_nulls = False
+    polars_broken_concat_on_nested_list = False
+    polars_group_by_instead_of_groupby = False
+    polars_name_dot_suffix_instead_of_suffix = False
+    polars_lazy_frame_collect_schema = False
+    polars_allow_lit_empty_struct = False
+
+
+def pl_array(inner: pl.PolarsDataType, size: int) -> pl.Array:
+    """Create a Polars Array type with version-compatible parameter names.
+
+    Args:
+        inner: The inner data type of the array
+        size: The fixed size of the array
+
+    Returns:
+        A Polars Array type
+    """
+    try:
+        import polars as pl
+    except ImportError:
+        raise missing_dependency_exception("chalkpy[runtime]")
+
+    if polars_array_uses_shape:
+        return pl.Array(inner=inner, shape=size)
+    else:
+        return pl.Array(inner=inner, width=size)  # type: ignore[call-arg]


 def chunked_df_slices(df: pl.LazyFrame | pl.DataFrame, chunk_size: int) -> Iterator[pl.DataFrame]:
@@ -100,13 +143,13 @@ def pl_datetime_to_iso_string(expr: pl.Expr, tz_key: str | None) -> pl.Expr:
     else:
         return pl.format(
             "{}-{}-{}T{}:{}:{}.{}" + timezone,
-            expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),
-            expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),
+            expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),  # pyright: ignore -- polars backcompat
         )


@@ -126,9 +169,9 @@ def pl_date_to_iso_string(expr: pl.Expr) -> pl.Expr:
     else:
         return pl.format(
             "{}-{}-{}",
-            expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),
-            expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),
+            expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
         )


@@ -149,21 +192,39 @@ def pl_time_to_iso_string(expr: pl.Expr) -> pl.Expr:
     else:
         return pl.format(
             "{}:{}:{}.{}",
-            expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),
-            expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),
+            expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),  # pyright: ignore -- polars backcompat
+            expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),  # pyright: ignore -- polars backcompat
+        )
+
+
+def pl_dtype_swap(dtype: pl.PolarsDataType, _from: pl.PolarsDataType, to: pl.PolarsDataType) -> pl.PolarsDataType:
+    if isinstance(dtype, _from):
+        return to
+    if isinstance(dtype, pl.List):
+        return pl.List(inner=pl_dtype_swap(dtype.inner, _from, to))
+    if isinstance(dtype, pl.Struct):
+        return pl.Struct(
+            {field_name: pl_dtype_swap(field_dtype, _from, to) for field_name, field_dtype in dtype.to_schema().items()}
         )
+    return dtype


-def pl_json_decode(series: pl.Series, dtype: pl.PolarsDataType | None = None) -> pl.Series:
+def pl_json_decode(series: pl.Series, dtype: pl.PolarsDataType) -> pl.Series:
     if is_new_polars:
-        decoded_series = series.map_elements(json_loads, return_dtype=dtype)  # pyright: ignore -- polars backcompat
+        swapped_dtype = pl_dtype_swap(dtype, pl.Binary, pl.Utf8)
+        if swapped_dtype == pl.Utf8:
+            decoded_series = series.map_elements(json_loads_as_str, return_dtype=swapped_dtype).cast(
+                dtype
+            )  # pyright: ignore -- polars backcompat
+        else:
+            decoded_series = series.map_elements(json_loads, return_dtype=swapped_dtype).cast(
+                dtype
+            )  # pyright: ignore -- polars backcompat
     else:
-        decoded_series = series.apply(json_loads, return_dtype=dtype)
-        if dtype is not None:
-            # Special case -- for nested dtypes polars doesn't always respect the return_dtype
-            decoded_series = decoded_series.cast(dtype)
+        decoded_series = series.apply(json_loads, return_dtype=dtype)  # pyright: ignore -- polars backcompat
+        decoded_series = decoded_series.cast(dtype)
     return decoded_series


@@ -174,19 +235,33 @@ def pl_duration_to_iso_string(expr: pl.Expr) -> pl.Expr:
     except ImportError:
         raise missing_dependency_exception("chalkpy[runtime]")

-    return pl.format(
-        "{}P{}DT{}H{}M{}.{}S",
-        pl.when(expr.dt.microseconds() < 0).then(pl.lit("-")).otherwise(pl.lit("")),
-        expr.dt.days().abs().cast(pl.Utf8),
-        (expr.dt.hours().abs() % 24).cast(pl.Utf8),
-        (expr.dt.minutes().abs() % 60).cast(pl.Utf8),
-        (expr.dt.seconds().abs() % 60).cast(pl.Utf8),
-        (expr.dt.microseconds().abs() % 1_000_000)
-        .cast(pl.Utf8)
-        .str.pad_start(6, "0")  # pyright: ignore -- polars backcompat
-        if is_new_polars
-        else (expr.dt.microseconds().abs() % 1_000_000).cast(pl.Utf8).str.rjust(6, "0"),
-    )
+    try:
+        return pl.format(
+            "{}P{}DT{}H{}M{}.{}S",
+            pl.when(expr.dt.microseconds() < 0)  # pyright: ignore -- polars backcompat
+            .then(pl.lit("-"))
+            .otherwise(pl.lit("")),  # pyright: ignore -- polars backcompat
+            expr.dt.days().abs().cast(pl.Utf8),  # pyright: ignore -- polars backcompat
+            (expr.dt.hours().abs() % 24).cast(pl.Utf8),  # pyright: ignore -- polars backcompat
+            (expr.dt.minutes().abs() % 60).cast(pl.Utf8),  # pyright: ignore -- polars backcompat
+            (expr.dt.seconds().abs() % 60).cast(pl.Utf8),  # pyright: ignore -- polars backcompat
+            (expr.dt.microseconds().abs() % 1_000_000)  # pyright: ignore -- polars backcompat
+            .cast(pl.Utf8)
+            .str.pad_start(6, "0")  # pyright: ignore -- polars backcompat
+            if is_new_polars
+            else (expr.dt.microseconds().abs() % 1_000_000)  # pyright: ignore -- polars backcompat
+            .cast(pl.Utf8)
+            .str.rjust(6, "0"),  # pyright: ignore -- polars backcompat
+        )
+    except AttributeError:
+        return (
+            pl.format("{}P{}DT{}H{}M{}.{}S", expr.dt.total_microseconds().abs() % 1_000_000)
+            .cast(pl.Utf8)
+            .str.pad_start(
+                6,
+                "0",
+            )
+        )


 def pl_json_encode(expr: pl.Expr, dtype: pl.PolarsDataType):
@@ -374,7 +449,7 @@ def _json_encode_inner(expr: pl.Expr, dtype: pl.PolarsDataType) -> pl.Expr:
                     _backup_json_encode, return_dtype=pl.Utf8
                 )
             else:
-                return expr.apply(_backup_json_encode, return_dtype=pl.Utf8)
+                return expr.apply(_backup_json_encode, return_dtype=pl.Utf8)  # pyright: ignore -- polars backcompat
         expr = expr.fill_null([])
         lists_with_extra_none = (
             expr.list.concat(pl.lit(None))  # pyright: ignore -- back compat
@@ -469,3 +544,81 @@ def recursively_has_struct(dtype: pa.DataType) -> bool:
         assert isinstance(dtype, pa.MapType)
         return recursively_has_struct(dtype.key_type) or recursively_has_struct(dtype.item_type)
     return False
+
+
+def apply_compat(
+    expr: "pl.Expr",
+    function: Any,
+    return_dtype: "pl.PolarsDataType | None" = None,
+    **kwargs: Any,
+) -> "pl.Expr":
+    """
+    Apply a custom function to an expression in a version-compatible way.
+
+    In Polars >= 0.19, expr.apply() was deprecated in favor of expr.map_elements().
+    This function provides compatibility between versions.
+
+    Args:
+        expr: The Polars expression to apply the function to
+        function: The function to apply to each element
+        return_dtype: The return data type for the expression (optional)
+        **kwargs: Additional keyword arguments to pass to the underlying method
+
+    Returns:
+        A Polars expression with the function applied
+
+    Example:
+        >>> import polars as pl
+        >>> from chalkengine.utils.polars_compat_util import apply_compat
+        >>> df = pl.DataFrame({"a": [1, 2, 3]})
+        >>> df.select(apply_compat(pl.col("a"), lambda x: x * 2))
+    """
+    # Build kwargs for the call
+    call_kwargs = kwargs.copy()
+    if return_dtype is not None:
+        call_kwargs["return_dtype"] = return_dtype
+
+    try:
+        # Try newer API first: map_elements()
+        return expr.map_elements(function, **call_kwargs)  # type: ignore
+    except AttributeError:
+        # Fall back to older API: apply()
+        return expr.apply(function, **call_kwargs)  # type: ignore
+
+
+@overload
+def str_json_decode_compat(expr: "pl.Expr", dtype: "pl.PolarsDataType") -> "pl.Expr":
+    ...
+
+
+@overload
+def str_json_decode_compat(expr: "pl.Series", dtype: "pl.PolarsDataType") -> "pl.Series":
+    ...
+
+
+def str_json_decode_compat(expr: "pl.Expr | pl.Series", dtype: "pl.PolarsDataType") -> "pl.Expr | pl.Series":
+    """
+    Parse/decode JSON strings in a version-compatible way.
+
+    In newer Polars versions (>= 1.0), str.json_extract() was renamed to str.json_decode().
+    This function provides compatibility between versions.
+
+    Args:
+        expr: The Polars expression containing JSON strings to parse
+        dtype: The Polars data type to extract to
+
+    Returns:
+        A Polars expression that parses the JSON strings
+    """
+    try:
+        # Try newer API first: str.json_decode()
+        return expr.str.json_decode(dtype=dtype)  # type: ignore
+    except AttributeError:
+        # Fall back to older API: str.json_extract()
+        return expr.str.json_extract(dtype=dtype)  # type: ignore
+
+
+def schema_compat(df: "pl.DataFrame | pl.LazyFrame"):
+    if polars_lazy_frame_collect_schema and isinstance(df, pl.LazyFrame):
+        return df.collect_schema()
+    return df.schema
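Taken together, the pl_helpers.py changes centralize Polars version handling: a set of feature flags (polars_array_uses_shape, polars_lazy_frame_collect_schema, and so on) plus wrappers such as pl_array, apply_compat, str_json_decode_compat, and schema_compat that try the newer Polars API and fall back to the older spelling. A rough usage sketch, assuming some supported Polars version is installed; the column names and dtypes are illustrative only:

    import polars as pl

    from chalk.utils.pl_helpers import apply_compat, pl_array, schema_compat, str_json_decode_compat

    df = pl.DataFrame({"payload": ['{"a": 1}', '{"a": 2}'], "x": [1, 2]})

    # str.json_decode on newer Polars, str.json_extract on older releases
    decoded = df.select(str_json_decode_compat(pl.col("payload"), pl.Struct({"a": pl.Int64})))

    # expr.map_elements on newer Polars, expr.apply on older releases
    doubled = df.select(apply_compat(pl.col("x"), lambda v: v * 2, return_dtype=pl.Int64))

    # pl.Array(shape=...) on Polars >= 1.0, pl.Array(width=...) before that
    fixed_len = pl_array(pl.Int64, 3)

    # LazyFrame.collect_schema() where available, .schema otherwise
    schema = schema_compat(df.lazy())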
chalk/utils/pydanticutil/pydantic_compat.py CHANGED
@@ -2,12 +2,11 @@ from __future__ import annotations

 import json
 from inspect import isclass
-from typing import Any
+from typing import Any, TypeGuard

 import pydantic
 from packaging import version
 from pydantic import BaseModel
-from typing_extensions import TypeGuard

 try:
     from pydantic.v1 import BaseModel as V1BaseModel
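Both pl_helpers.py and pydantic_compat.py now import TypeGuard straight from typing instead of typing_extensions, which requires typing.TypeGuard to exist at runtime (it was added in Python 3.10). For comparison only, a hedged sketch of the usual fallback pattern that the diff above removes; this is not what the package does as of 2.95.3:

    # typing.TypeGuard exists on Python 3.10+; typing_extensions backports it to older interpreters.
    try:
        from typing import TypeGuard
    except ImportError:
        from typing_extensions import TypeGuard


    def is_str_list(values: list[object]) -> TypeGuard[list[str]]:
        # Narrows list[object] to list[str] for type checkers when True is returned.
        return all(isinstance(v, str) for v in values)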