chalkpy 2.89.22__py3-none-any.whl → 2.95.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +2 -1
- chalk/_gen/chalk/arrow/v1/arrow_pb2.py +7 -5
- chalk/_gen/chalk/arrow/v1/arrow_pb2.pyi +6 -0
- chalk/_gen/chalk/artifacts/v1/chart_pb2.py +36 -33
- chalk/_gen/chalk/artifacts/v1/chart_pb2.pyi +41 -1
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.py +8 -7
- chalk/_gen/chalk/artifacts/v1/cron_query_pb2.pyi +5 -0
- chalk/_gen/chalk/common/v1/offline_query_pb2.py +19 -13
- chalk/_gen/chalk/common/v1/offline_query_pb2.pyi +37 -0
- chalk/_gen/chalk/common/v1/online_query_pb2.py +54 -54
- chalk/_gen/chalk/common/v1/online_query_pb2.pyi +13 -1
- chalk/_gen/chalk/common/v1/script_task_pb2.py +13 -11
- chalk/_gen/chalk/common/v1/script_task_pb2.pyi +19 -1
- chalk/_gen/chalk/dataframe/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/__init__.py +0 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.py +48 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2.pyi +123 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.py +4 -0
- chalk/_gen/chalk/dataframe/v1/dataframe_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/graph/v1/graph_pb2.py +150 -149
- chalk/_gen/chalk/graph/v1/graph_pb2.pyi +25 -0
- chalk/_gen/chalk/graph/v1/sources_pb2.py +94 -84
- chalk/_gen/chalk/graph/v1/sources_pb2.pyi +56 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.py +79 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2.pyi +377 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.py +4 -0
- chalk/_gen/chalk/kubernetes/v1/horizontalpodautoscaler_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.py +43 -7
- chalk/_gen/chalk/kubernetes/v1/scaledobject_pb2.pyi +252 -2
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.py +54 -27
- chalk/_gen/chalk/protosql/v1/sql_service_pb2.pyi +131 -3
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.py +45 -0
- chalk/_gen/chalk/protosql/v1/sql_service_pb2_grpc.pyi +14 -0
- chalk/_gen/chalk/python/v1/types_pb2.py +14 -14
- chalk/_gen/chalk/python/v1/types_pb2.pyi +8 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.py +76 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2.pyi +156 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.py +258 -0
- chalk/_gen/chalk/server/v1/benchmark_pb2_grpc.pyi +84 -0
- chalk/_gen/chalk/server/v1/billing_pb2.py +40 -38
- chalk/_gen/chalk/server/v1/billing_pb2.pyi +17 -1
- chalk/_gen/chalk/server/v1/branches_pb2.py +45 -0
- chalk/_gen/chalk/server/v1/branches_pb2.pyi +80 -0
- chalk/_gen/chalk/server/v1/branches_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/builder_pb2.py +372 -272
- chalk/_gen/chalk/server/v1/builder_pb2.pyi +479 -12
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.py +360 -0
- chalk/_gen/chalk/server/v1/builder_pb2_grpc.pyi +96 -0
- chalk/_gen/chalk/server/v1/chart_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/chart_pb2.pyi +18 -2
- chalk/_gen/chalk/server/v1/clickhouse_pb2.py +42 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2.pyi +17 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/clickhouse_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2.py +153 -107
- chalk/_gen/chalk/server/v1/cloud_components_pb2.pyi +146 -4
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/cloud_components_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.py +11 -3
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2.pyi +20 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/cloud_credentials_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.py +59 -35
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2.pyi +127 -1
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/dataplanejobqueue_pb2_grpc.pyi +36 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.py +90 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2.pyi +264 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.py +170 -0
- chalk/_gen/chalk/server/v1/dataplaneworkflows_pb2_grpc.pyi +62 -0
- chalk/_gen/chalk/server/v1/datasets_pb2.py +36 -24
- chalk/_gen/chalk/server/v1/datasets_pb2.pyi +71 -2
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/datasets_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/deploy_pb2.pyi +12 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/deploy_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/deployment_pb2.py +20 -15
- chalk/_gen/chalk/server/v1/deployment_pb2.pyi +25 -0
- chalk/_gen/chalk/server/v1/environment_pb2.py +25 -15
- chalk/_gen/chalk/server/v1/environment_pb2.pyi +93 -1
- chalk/_gen/chalk/server/v1/eventbus_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2.pyi +64 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/eventbus_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/files_pb2.py +65 -0
- chalk/_gen/chalk/server/v1/files_pb2.pyi +167 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/files_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/graph_pb2.py +41 -3
- chalk/_gen/chalk/server/v1/graph_pb2.pyi +191 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.py +92 -0
- chalk/_gen/chalk/server/v1/graph_pb2_grpc.pyi +32 -0
- chalk/_gen/chalk/server/v1/incident_pb2.py +57 -0
- chalk/_gen/chalk/server/v1/incident_pb2.pyi +165 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.py +4 -0
- chalk/_gen/chalk/server/v1/incident_pb2_grpc.pyi +4 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.py +44 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2.pyi +38 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.py +78 -0
- chalk/_gen/chalk/server/v1/indexing_job_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/integrations_pb2.py +11 -9
- chalk/_gen/chalk/server/v1/integrations_pb2.pyi +34 -2
- chalk/_gen/chalk/server/v1/kube_pb2.py +29 -19
- chalk/_gen/chalk/server/v1/kube_pb2.pyi +28 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/kube_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/log_pb2.py +21 -3
- chalk/_gen/chalk/server/v1/log_pb2.pyi +68 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/log_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.py +73 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2.pyi +212 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.py +217 -0
- chalk/_gen/chalk/server/v1/metadataplanejobqueue_pb2_grpc.pyi +74 -0
- chalk/_gen/chalk/server/v1/model_registry_pb2.py +10 -10
- chalk/_gen/chalk/server/v1/model_registry_pb2.pyi +4 -1
- chalk/_gen/chalk/server/v1/monitoring_pb2.py +84 -75
- chalk/_gen/chalk/server/v1/monitoring_pb2.pyi +1 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.py +136 -0
- chalk/_gen/chalk/server/v1/monitoring_pb2_grpc.pyi +38 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2.py +32 -10
- chalk/_gen/chalk/server/v1/offline_queries_pb2.pyi +73 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/offline_queries_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.py +53 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2.pyi +86 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.py +168 -0
- chalk/_gen/chalk/server/v1/plandebug_pb2_grpc.pyi +60 -0
- chalk/_gen/chalk/server/v1/queries_pb2.py +76 -48
- chalk/_gen/chalk/server/v1/queries_pb2.pyi +155 -2
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.py +180 -0
- chalk/_gen/chalk/server/v1/queries_pb2_grpc.pyi +48 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2.py +4 -2
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/scheduled_query_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.py +12 -6
- chalk/_gen/chalk/server/v1/scheduled_query_run_pb2.pyi +75 -2
- chalk/_gen/chalk/server/v1/scheduler_pb2.py +24 -12
- chalk/_gen/chalk/server/v1/scheduler_pb2.pyi +61 -1
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/scheduler_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2.py +26 -14
- chalk/_gen/chalk/server/v1/script_tasks_pb2.pyi +33 -3
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/script_tasks_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.py +75 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2.pyi +142 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.py +349 -0
- chalk/_gen/chalk/server/v1/sql_interface_pb2_grpc.pyi +114 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.py +48 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2.pyi +150 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.py +123 -0
- chalk/_gen/chalk/server/v1/sql_queries_pb2_grpc.pyi +52 -0
- chalk/_gen/chalk/server/v1/team_pb2.py +156 -137
- chalk/_gen/chalk/server/v1/team_pb2.pyi +56 -10
- chalk/_gen/chalk/server/v1/team_pb2_grpc.py +90 -0
- chalk/_gen/chalk/server/v1/team_pb2_grpc.pyi +24 -0
- chalk/_gen/chalk/server/v1/topic_pb2.py +5 -3
- chalk/_gen/chalk/server/v1/topic_pb2.pyi +10 -1
- chalk/_gen/chalk/server/v1/trace_pb2.py +50 -28
- chalk/_gen/chalk/server/v1/trace_pb2.pyi +121 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.py +135 -0
- chalk/_gen/chalk/server/v1/trace_pb2_grpc.pyi +42 -0
- chalk/_gen/chalk/server/v1/webhook_pb2.py +9 -3
- chalk/_gen/chalk/server/v1/webhook_pb2.pyi +18 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.py +45 -0
- chalk/_gen/chalk/server/v1/webhook_pb2_grpc.pyi +12 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.py +62 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2.pyi +75 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.py +221 -0
- chalk/_gen/chalk/streaming/v1/debug_service_pb2_grpc.pyi +88 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.py +19 -7
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2.pyi +96 -3
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.py +48 -0
- chalk/_gen/chalk/streaming/v1/simple_streaming_service_pb2_grpc.pyi +20 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.py +32 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2.pyi +42 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.py +4 -0
- chalk/_gen/chalk/utils/v1/field_change_pb2_grpc.pyi +4 -0
- chalk/_lsp/error_builder.py +11 -0
- chalk/_monitoring/Chart.py +1 -3
- chalk/_version.py +1 -1
- chalk/cli.py +5 -10
- chalk/client/client.py +178 -64
- chalk/client/client_async.py +154 -0
- chalk/client/client_async_impl.py +22 -0
- chalk/client/client_grpc.py +738 -112
- chalk/client/client_impl.py +541 -136
- chalk/client/dataset.py +27 -6
- chalk/client/models.py +99 -2
- chalk/client/serialization/model_serialization.py +126 -10
- chalk/config/project_config.py +1 -1
- chalk/df/LazyFramePlaceholder.py +1154 -0
- chalk/df/ast_parser.py +2 -10
- chalk/features/_class_property.py +7 -0
- chalk/features/_embedding/embedding.py +1 -0
- chalk/features/_embedding/sentence_transformer.py +1 -1
- chalk/features/_encoding/converter.py +83 -2
- chalk/features/_encoding/pyarrow.py +20 -4
- chalk/features/_encoding/rich.py +1 -3
- chalk/features/_tensor.py +1 -2
- chalk/features/dataframe/_filters.py +14 -5
- chalk/features/dataframe/_impl.py +91 -36
- chalk/features/dataframe/_validation.py +11 -7
- chalk/features/feature_field.py +40 -30
- chalk/features/feature_set.py +1 -2
- chalk/features/feature_set_decorator.py +1 -0
- chalk/features/feature_wrapper.py +42 -3
- chalk/features/hooks.py +81 -12
- chalk/features/inference.py +65 -10
- chalk/features/resolver.py +338 -56
- chalk/features/tag.py +1 -3
- chalk/features/underscore_features.py +2 -1
- chalk/functions/__init__.py +456 -21
- chalk/functions/holidays.py +1 -3
- chalk/gitignore/gitignore_parser.py +5 -1
- chalk/importer.py +186 -74
- chalk/ml/__init__.py +6 -2
- chalk/ml/model_hooks.py +368 -51
- chalk/ml/model_reference.py +68 -10
- chalk/ml/model_version.py +34 -21
- chalk/ml/utils.py +143 -40
- chalk/operators/_utils.py +14 -3
- chalk/parsed/_proto/export.py +22 -0
- chalk/parsed/duplicate_input_gql.py +4 -0
- chalk/parsed/expressions.py +1 -3
- chalk/parsed/json_conversions.py +21 -14
- chalk/parsed/to_proto.py +16 -4
- chalk/parsed/user_types_to_json.py +31 -10
- chalk/parsed/validation_from_registries.py +182 -0
- chalk/queries/named_query.py +16 -6
- chalk/queries/scheduled_query.py +13 -1
- chalk/serialization/parsed_annotation.py +25 -12
- chalk/sql/__init__.py +221 -0
- chalk/sql/_internal/integrations/athena.py +6 -1
- chalk/sql/_internal/integrations/bigquery.py +22 -2
- chalk/sql/_internal/integrations/databricks.py +61 -18
- chalk/sql/_internal/integrations/mssql.py +281 -0
- chalk/sql/_internal/integrations/postgres.py +11 -3
- chalk/sql/_internal/integrations/redshift.py +4 -0
- chalk/sql/_internal/integrations/snowflake.py +11 -2
- chalk/sql/_internal/integrations/util.py +2 -1
- chalk/sql/_internal/sql_file_resolver.py +55 -10
- chalk/sql/_internal/sql_source.py +36 -2
- chalk/streams/__init__.py +1 -3
- chalk/streams/_kafka_source.py +5 -1
- chalk/streams/_windows.py +16 -4
- chalk/streams/types.py +1 -2
- chalk/utils/__init__.py +1 -3
- chalk/utils/_otel_version.py +13 -0
- chalk/utils/async_helpers.py +14 -5
- chalk/utils/df_utils.py +2 -2
- chalk/utils/duration.py +1 -3
- chalk/utils/job_log_display.py +538 -0
- chalk/utils/missing_dependency.py +5 -4
- chalk/utils/notebook.py +255 -2
- chalk/utils/pl_helpers.py +190 -37
- chalk/utils/pydanticutil/pydantic_compat.py +1 -2
- chalk/utils/storage_client.py +246 -0
- chalk/utils/threading.py +1 -3
- chalk/utils/tracing.py +194 -86
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/METADATA +53 -21
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/RECORD +268 -198
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/WHEEL +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/entry_points.txt +0 -0
- {chalkpy-2.89.22.dist-info → chalkpy-2.95.3.dist-info}/top_level.txt +0 -0
chalk/utils/notebook.py
CHANGED
|
@@ -1,15 +1,21 @@
|
|
|
1
|
+
import ast
|
|
1
2
|
import enum
|
|
2
3
|
import functools
|
|
3
4
|
import inspect
|
|
4
5
|
import sys
|
|
5
6
|
from contextvars import ContextVar
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
7
8
|
|
|
8
9
|
from chalk.utils.environment_parsing import env_var_bool
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
11
12
|
from chalk.sql._internal.sql_file_resolver import SQLStringResult
|
|
12
13
|
|
|
14
|
+
if TYPE_CHECKING:
    # The importable package is `IPython` (capitalised). Only type checkers see the real class;
    # at runtime the name is an alias for `Any`, so loading this module never imports IPython eagerly.
    from IPython.core.interactiveshell import InteractiveShell  # type: ignore
else:
    InteractiveShell = Any
|
|
18
|
+
|
|
13
19
|
|
|
14
20
|
def print_user_error(message: str, exception: Optional[Exception] = None, suggested_action: Optional[str] = None):
|
|
15
21
|
print(f"\033[91mERROR: {message}\033[0m", file=sys.stderr)
|
|
@@ -29,7 +35,7 @@ class IPythonEvents(enum.Enum):
|
|
|
29
35
|
POST_RUN_CELL = "post_run_cell"
|
|
30
36
|
|
|
31
37
|
|
|
32
|
-
def get_ipython_or_none() -> Optional[
|
|
38
|
+
def get_ipython_or_none() -> Optional[Any]:
|
|
33
39
|
"""
|
|
34
40
|
Returns the global IPython shell object, if this code is running in an ipython environment.
|
|
35
41
|
:return: An `IPython.core.interactiveshell.InteractiveShell`, or None if we're not running in a notebook/ipython repl
|
|
@@ -129,3 +135,250 @@ def register_resolver_from_cell_magic(sql_string_result: "SQLStringResult"):
|
|
|
129
135
|
return
|
|
130
136
|
|
|
131
137
|
NOTEBOOK_DEFINED_SQL_RESOLVERS[sql_string_result.path] = resolver_result
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def is_valid_python_code(code_string: str) -> bool:
    """Report whether ``code_string`` compiles as a Python module.

    Compilation (rather than only parsing) is used, so compile-time errors such as a
    ``return`` outside a function are also rejected.

    Returns:
        ``True`` if ``compile`` succeeds, ``False`` if it raises ``SyntaxError`` or ``ValueError``.
    """
    try:
        compile(code_string, "<string>", "exec")
    except (SyntaxError, ValueError):
        return False
    else:
        return True
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _get_import_names(node: Union[ast.Import, ast.ImportFrom], cell_source: str, import_source: str) -> set[str]:
|
|
149
|
+
"""Extract the names that an import statement brings into scope."""
|
|
150
|
+
import ast
|
|
151
|
+
|
|
152
|
+
imported_names = set()
|
|
153
|
+
if isinstance(node, ast.Import):
|
|
154
|
+
for alias in node.names:
|
|
155
|
+
name = alias.asname if alias.asname else alias.name
|
|
156
|
+
imported_names.add(name)
|
|
157
|
+
else: # ast.ImportFrom
|
|
158
|
+
for alias in node.names:
|
|
159
|
+
if alias.name == "*":
|
|
160
|
+
# Can't track wildcard imports precisely, so include the import text itself
|
|
161
|
+
imported_names.add(import_source)
|
|
162
|
+
else:
|
|
163
|
+
name = alias.asname if alias.asname else alias.name
|
|
164
|
+
imported_names.add(name)
|
|
165
|
+
return imported_names
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _parse_notebook_cells(cells: list[tuple[int, int, str]]):
|
|
169
|
+
"""Parse notebook cells and extract definitions of functions, classes, globals, and imports."""
|
|
170
|
+
import ast
|
|
171
|
+
|
|
172
|
+
latest_function_def: dict[str, tuple[str, ast.AST]] = {} # name -> (source, ast_node)
|
|
173
|
+
latest_global_assign: dict[str, str] = {} # name -> source
|
|
174
|
+
latest_class_def: dict[str, tuple[str, ast.AST]] = {} # name -> (source, ast_node)
|
|
175
|
+
all_imports: dict[str, tuple[list[str], ast.AST]] = {} # import_text -> (names_imported, ast_node)
|
|
176
|
+
|
|
177
|
+
for _, _, cell_source in cells:
|
|
178
|
+
cell_source = cell_source.strip()
|
|
179
|
+
if not cell_source:
|
|
180
|
+
continue
|
|
181
|
+
|
|
182
|
+
try:
|
|
183
|
+
cell_tree = ast.parse(cell_source)
|
|
184
|
+
except SyntaxError:
|
|
185
|
+
continue
|
|
186
|
+
|
|
187
|
+
for node in cell_tree.body:
|
|
188
|
+
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
|
189
|
+
import_source = ast.get_source_segment(cell_source, node)
|
|
190
|
+
if import_source is None:
|
|
191
|
+
continue
|
|
192
|
+
imported_names = _get_import_names(node, cell_source, import_source)
|
|
193
|
+
all_imports[import_source] = (list(imported_names), node)
|
|
194
|
+
|
|
195
|
+
elif isinstance(node, ast.FunctionDef):
|
|
196
|
+
func_source = ast.get_source_segment(cell_source, node)
|
|
197
|
+
if func_source is not None:
|
|
198
|
+
latest_function_def[node.name] = (func_source, node)
|
|
199
|
+
|
|
200
|
+
elif isinstance(node, ast.ClassDef):
|
|
201
|
+
class_source = ast.get_source_segment(cell_source, node)
|
|
202
|
+
if class_source is not None:
|
|
203
|
+
latest_class_def[node.name] = (class_source, node)
|
|
204
|
+
|
|
205
|
+
elif isinstance(node, ast.Assign):
|
|
206
|
+
for target in node.targets:
|
|
207
|
+
if isinstance(target, ast.Name):
|
|
208
|
+
assign_source = ast.get_source_segment(cell_source, node)
|
|
209
|
+
if assign_source is not None:
|
|
210
|
+
latest_global_assign[target.id] = assign_source
|
|
211
|
+
|
|
212
|
+
return latest_function_def, latest_class_def, latest_global_assign, all_imports
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _get_referenced_names(source_code: str) -> set[str]:
|
|
216
|
+
"""Extract all names referenced in source code."""
|
|
217
|
+
import ast
|
|
218
|
+
|
|
219
|
+
try:
|
|
220
|
+
tree = ast.parse(source_code)
|
|
221
|
+
except SyntaxError:
|
|
222
|
+
return set()
|
|
223
|
+
|
|
224
|
+
names = set()
|
|
225
|
+
for node in ast.walk(tree):
|
|
226
|
+
if isinstance(node, ast.Name):
|
|
227
|
+
names.add(node.id)
|
|
228
|
+
elif isinstance(node, ast.Attribute):
|
|
229
|
+
# For module.function, capture the base module
|
|
230
|
+
if isinstance(node.value, ast.Name):
|
|
231
|
+
names.add(node.value.id)
|
|
232
|
+
return names
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _collect_dependencies(
    fn_source: str,
    fn_name: str,
    latest_function_def: dict[str, tuple[str, ast.AST]],
    latest_class_def: dict[str, tuple[str, ast.AST]],
    latest_global_assign: dict[str, str],
    builtin_names: set[str],
) -> tuple[dict[str, str], dict[str, str], dict[str, str], set[str]]:
    """Find every notebook definition reachable from ``fn_source``.

    Every name referenced by the target function (excluding builtins and ``fn_name``) is
    looked up among the notebook's classes, functions and global assignments. Each match is
    recorded, and its own source is scanned the same way until nothing new turns up.

    Returns:
        ``(functions, classes, globals, names)``. The first three map a name to the source
        that defines it. ``names`` is every non-builtin name referenced along the way.
    """
    functions_found: dict[str, str] = {}
    classes_found: dict[str, str] = {}
    globals_found: dict[str, str] = {}
    names_seen: set[str] = set()

    pending_sources = [fn_source]
    scanned_sources: set[str] = set()

    def _record(found: dict[str, str], name: str, source: str) -> None:
        # Remember the definition and schedule its source for scanning.
        found[name] = source
        pending_sources.append(source)

    while pending_sources:
        source = pending_sources.pop()
        if source in scanned_sources:
            continue
        scanned_sources.add(source)

        names = _get_referenced_names(source) - builtin_names - {fn_name}
        names_seen.update(names)

        # A class definition wins over a function of the same name; the function branch is only
        # reached when the class is unknown or already recorded.
        for name in names:
            if name in latest_class_def and name not in classes_found:
                _record(classes_found, name, latest_class_def[name][0])
            elif name in latest_function_def and name not in functions_found:
                _record(functions_found, name, latest_function_def[name][0])

        # Globals are scheduled after the classes/functions found in the same source.
        for name in names:
            if name in latest_global_assign and name not in globals_found:
                _record(globals_found, name, latest_global_assign[name])

    return functions_found, classes_found, globals_found, names_seen
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _filter_imports(all_imports: dict[str, tuple[list[str], ast.AST]], needed_names: set[str]) -> list[str]:
|
|
287
|
+
"""Filter imports to only include those that are actually used."""
|
|
288
|
+
needed_imports: list[str] = []
|
|
289
|
+
for import_text, (imported_names, _) in all_imports.items():
|
|
290
|
+
if any(name in needed_names or name == import_text for name in imported_names):
|
|
291
|
+
needed_imports.append(import_text)
|
|
292
|
+
return needed_imports
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _build_script(
|
|
296
|
+
fn_source: str,
|
|
297
|
+
fn_name: str,
|
|
298
|
+
needed_imports: list[str],
|
|
299
|
+
needed_globals: dict[str, str],
|
|
300
|
+
needed_classes: dict[str, str],
|
|
301
|
+
needed_functions: dict[str, str],
|
|
302
|
+
) -> str:
|
|
303
|
+
"""Build the final script from collected components."""
|
|
304
|
+
script_parts: list[str] = []
|
|
305
|
+
|
|
306
|
+
if needed_imports:
|
|
307
|
+
script_parts.extend(needed_imports)
|
|
308
|
+
script_parts.append("")
|
|
309
|
+
|
|
310
|
+
if needed_globals:
|
|
311
|
+
script_parts.extend(needed_globals.values())
|
|
312
|
+
script_parts.append("")
|
|
313
|
+
|
|
314
|
+
if needed_classes:
|
|
315
|
+
script_parts.extend(needed_classes.values())
|
|
316
|
+
script_parts.append("")
|
|
317
|
+
|
|
318
|
+
if needed_functions:
|
|
319
|
+
script_parts.extend(needed_functions.values())
|
|
320
|
+
script_parts.append("")
|
|
321
|
+
|
|
322
|
+
script_parts.append(fn_source)
|
|
323
|
+
|
|
324
|
+
return "\n".join(script_parts)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def parse_notebook_into_script(fn: Callable[..., None], takes_argument: bool) -> str:
    """
    Parse a notebook function and its dependencies into a standalone Python script.

    The script contains the notebook imports, globals, classes and helper functions that ``fn``
    references, directly or through those helpers, followed by ``fn``'s own source. The script
    only *defines* ``fn`` and does not call it; invoking it is left to the caller.

    Args:
        fn: The notebook-defined function to export. It should produce no return value.
        takes_argument: ``True`` if ``fn`` must take exactly one parameter, ``False`` if it
            must take none.

    Returns:
        str: A Python script as a string.

    Raises:
        RuntimeError: Raised in any of these cases:

            - Not running in a notebook.
            - The IPython shell or its history manager is unavailable.
            - The generated script does not compile.
        ValueError: If ``fn`` does not have the expected number of parameters.
    """
    import builtins
    import textwrap

    if not is_notebook():
        raise RuntimeError("parse_notebook_into_script should only be called from a notebook environment.")

    sig = inspect.signature(fn)
    expected_param_count = int(takes_argument)
    if len(sig.parameters) != expected_param_count:
        raise ValueError(
            f"Function {fn.__name__} must take {expected_param_count} inputs, but has parameters: {list(sig.parameters.keys())}"
        )

    shell = get_ipython_or_none()
    if shell is None:
        raise RuntimeError("Could not access IPython shell")

    # Get the cell contents of executed cells
    history_manager = getattr(shell, "history_manager", None)
    if history_manager is None:
        raise RuntimeError("Could not access IPython history manager")

    # session=0 selects this kernel's own session. The newest session in the history database
    # (get_last_session_id) can belong to another IPython process using the same profile.
    # raw=False returns IPython-translated input, in which `%magic` / `!shell` lines have become
    # ordinary Python calls. With raw input, such a line makes the whole cell unparseable and
    # every definition in that cell would be skipped.
    cells = list(history_manager.get_range(session=0, start=1, raw=False))

    # Parse cells to extract definitions
    latest_function_def, latest_class_def, latest_global_assign, all_imports = _parse_notebook_cells(cells)

    # Get function source and collect dependencies. getsource returns the text as written, so
    # dedent it; otherwise an indented definition would not parse on its own.
    fn_source = textwrap.dedent(inspect.getsource(fn))
    builtin_names = set(dir(builtins))

    needed_functions, needed_classes, needed_globals, needed_names = _collect_dependencies(
        fn_source, fn.__name__, latest_function_def, latest_class_def, latest_global_assign, builtin_names
    )

    # Filter imports to only used ones
    needed_imports = _filter_imports(all_imports, needed_names)

    # Build and return the script
    script = _build_script(fn_source, fn.__name__, needed_imports, needed_globals, needed_classes, needed_functions)

    if not is_valid_python_code(script):
        raise RuntimeError("Error generating valid training function from notebook")

    return script
|
chalk/utils/pl_helpers.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import itertools
|
|
4
|
+
import zoneinfo
|
|
4
5
|
from datetime import timedelta
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Iterator, TypeVar
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Iterator, TypeGuard, TypeVar, overload
|
|
6
7
|
|
|
7
8
|
import pyarrow as pa
|
|
8
|
-
import zoneinfo
|
|
9
9
|
from packaging.version import parse
|
|
10
|
-
from typing_extensions import TypeGuard
|
|
11
10
|
|
|
12
11
|
from chalk.utils.log_with_context import get_logger
|
|
13
12
|
from chalk.utils.missing_dependency import missing_dependency_exception
|
|
@@ -27,6 +26,13 @@ except ImportError:
|
|
|
27
26
|
json_loads = json.loads
|
|
28
27
|
|
|
29
28
|
|
|
29
|
+
def json_loads_as_str(x: str | None):
|
|
30
|
+
if x is None:
|
|
31
|
+
return None
|
|
32
|
+
x = json_loads(x)
|
|
33
|
+
return x if x is None else str(x)
|
|
34
|
+
|
|
35
|
+
|
|
30
36
|
def is_version_gte(version: str, target: str) -> bool:
|
|
31
37
|
return parse(version) >= parse(target)
|
|
32
38
|
|
|
@@ -36,9 +42,46 @@ try:
|
|
|
36
42
|
|
|
37
43
|
is_new_polars = is_version_gte(pl.__version__, "0.18.0")
|
|
38
44
|
polars_has_pad_start = is_version_gte(pl.__version__, "0.19.12")
|
|
45
|
+
polars_array_uses_shape = is_version_gte(pl.__version__, "1.0.0")
|
|
46
|
+
polars_uses_schema_overrides = is_version_gte(pl.__version__, "0.20.31")
|
|
47
|
+
polars_join_ignores_nulls = is_version_gte(pl.__version__, "0.20.0")
|
|
48
|
+
polars_broken_concat_on_nested_list = is_version_gte(pl.__version__, "1.0.0")
|
|
49
|
+
polars_group_by_instead_of_groupby = is_version_gte(pl.__version__, "1.0.0")
|
|
50
|
+
polars_name_dot_suffix_instead_of_suffix = is_version_gte(pl.__version__, "1.0.0")
|
|
51
|
+
polars_lazy_frame_collect_schema = is_version_gte(pl.__version__, "1.0.0")
|
|
52
|
+
polars_allow_lit_empty_struct = is_version_gte(pl.__version__, "1.0.0")
|
|
39
53
|
except ImportError:
|
|
40
54
|
is_new_polars = False
|
|
41
55
|
polars_has_pad_start = False
|
|
56
|
+
polars_array_uses_shape = False
|
|
57
|
+
polars_uses_schema_overrides = False
|
|
58
|
+
polars_join_ignores_nulls = False
|
|
59
|
+
polars_broken_concat_on_nested_list = False
|
|
60
|
+
polars_group_by_instead_of_groupby = False
|
|
61
|
+
polars_name_dot_suffix_instead_of_suffix = False
|
|
62
|
+
polars_lazy_frame_collect_schema = False
|
|
63
|
+
polars_allow_lit_empty_struct = False
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def pl_array(inner: pl.PolarsDataType, size: int) -> pl.Array:
    """Create a fixed-size Polars ``Array`` dtype using the keyword this polars version expects.

    On polars >= 1.0.0 (``polars_array_uses_shape``) the length is passed as ``shape``;
    older versions receive it as ``width``.

    Args:
        inner: The element data type of the array.
        size: The fixed number of elements per array value.

    Returns:
        A Polars Array type.

    Raises:
        The ``missing_dependency_exception`` for ``chalkpy[runtime]`` if polars is not installed.
    """
    try:
        import polars as pl
    except ImportError:
        raise missing_dependency_exception("chalkpy[runtime]")

    if not polars_array_uses_shape:
        return pl.Array(inner=inner, width=size)  # type: ignore[call-arg]
    return pl.Array(inner=inner, shape=size)
|
|
42
85
|
|
|
43
86
|
|
|
44
87
|
def chunked_df_slices(df: pl.LazyFrame | pl.DataFrame, chunk_size: int) -> Iterator[pl.DataFrame]:
|
|
@@ -100,13 +143,13 @@ def pl_datetime_to_iso_string(expr: pl.Expr, tz_key: str | None) -> pl.Expr:
|
|
|
100
143
|
else:
|
|
101
144
|
return pl.format(
|
|
102
145
|
"{}-{}-{}T{}:{}:{}.{}" + timezone,
|
|
103
|
-
expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),
|
|
104
|
-
expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),
|
|
105
|
-
expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),
|
|
106
|
-
expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),
|
|
107
|
-
expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),
|
|
108
|
-
expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),
|
|
109
|
-
expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),
|
|
146
|
+
expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"), # pyright: ignore -- polars backcompat
|
|
147
|
+
expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
148
|
+
expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
149
|
+
expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
150
|
+
expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
151
|
+
expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
152
|
+
expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"), # pyright: ignore -- polars backcompat
|
|
110
153
|
)
|
|
111
154
|
|
|
112
155
|
|
|
@@ -126,9 +169,9 @@ def pl_date_to_iso_string(expr: pl.Expr) -> pl.Expr:
|
|
|
126
169
|
else:
|
|
127
170
|
return pl.format(
|
|
128
171
|
"{}-{}-{}",
|
|
129
|
-
expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"),
|
|
130
|
-
expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"),
|
|
131
|
-
expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"),
|
|
172
|
+
expr.dt.year().cast(pl.Utf8).str.rjust(4, "0"), # pyright: ignore -- polars backcompat
|
|
173
|
+
expr.dt.month().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
174
|
+
expr.dt.day().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
132
175
|
)
|
|
133
176
|
|
|
134
177
|
|
|
@@ -149,21 +192,39 @@ def pl_time_to_iso_string(expr: pl.Expr) -> pl.Expr:
|
|
|
149
192
|
else:
|
|
150
193
|
return pl.format(
|
|
151
194
|
"{}:{}:{}.{}",
|
|
152
|
-
expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"),
|
|
153
|
-
expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"),
|
|
154
|
-
expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"),
|
|
155
|
-
expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"),
|
|
195
|
+
expr.dt.hour().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
196
|
+
expr.dt.minute().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
197
|
+
expr.dt.second().cast(pl.Utf8).str.rjust(2, "0"), # pyright: ignore -- polars backcompat
|
|
198
|
+
expr.dt.microsecond().cast(pl.Utf8).str.rjust(6, "0"), # pyright: ignore -- polars backcompat
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def pl_dtype_swap(dtype: pl.PolarsDataType, _from: pl.PolarsDataType, to: pl.PolarsDataType) -> pl.PolarsDataType:
    """Recursively replace every occurrence of the dtype ``_from`` inside ``dtype`` with ``to``.

    ``pl.List`` inner types and ``pl.Struct`` field types are rewritten recursively;
    any other dtype that does not match ``_from`` is returned unchanged.

    Args:
        dtype: The dtype to rewrite. Either a dtype instance (``pl.Binary()``) or a
            bare dtype class (``pl.Binary``) is accepted.
        _from: The dtype class to look for (matched with ``isinstance`` / ``issubclass``).
        to: The dtype to substitute wherever ``_from`` matches.

    Returns:
        ``dtype`` with every matching ``_from`` replaced by ``to``.
    """
    # Polars accepts dtypes both as classes (``pl.Binary``) and as instances
    # (``pl.Binary()``). ``isinstance`` alone misses the class form, so also
    # accept a dtype class that is ``_from`` or a subclass of it.
    if isinstance(dtype, _from) or (isinstance(dtype, type) and issubclass(dtype, _from)):
        return to
    if isinstance(dtype, pl.List):
        return pl.List(inner=pl_dtype_swap(dtype.inner, _from, to))
    if isinstance(dtype, pl.Struct):
        return pl.Struct(
            {field_name: pl_dtype_swap(field_dtype, _from, to) for field_name, field_dtype in dtype.to_schema().items()}
        )
    return dtype
|
|
157
212
|
|
|
158
213
|
|
|
159
|
-
def pl_json_decode(series: pl.Series, dtype: pl.PolarsDataType) -> pl.Series:
    """Decode each element of ``series`` as JSON and return a series of ``dtype``.

    On newer polars (``is_new_polars``), any ``pl.Binary`` inside ``dtype`` is first
    swapped for ``pl.Utf8`` (via ``pl_dtype_swap``). That intermediate dtype is used
    as the ``map_elements`` return dtype, and the result is cast back to ``dtype``.
    ``json_loads_as_str`` is the decoder when the intermediate dtype equals
    ``pl.Utf8``; otherwise ``json_loads`` is used.

    On older polars, ``Series.apply`` with ``json_loads`` is used, followed by an
    explicit cast to ``dtype``.

    Args:
        series: Series whose elements are JSON documents.
        dtype: Target polars dtype of the decoded series.

    Returns:
        The decoded series, cast to ``dtype``.
    """
    if not is_new_polars:
        # Legacy polars only has ``apply``. The trailing cast is kept because, for
        # nested dtypes, polars doesn't always respect ``return_dtype``.
        legacy_decoded = series.apply(json_loads, return_dtype=dtype)  # pyright: ignore -- polars backcompat
        return legacy_decoded.cast(dtype)

    intermediate_dtype = pl_dtype_swap(dtype, pl.Binary, pl.Utf8)
    loader = json_loads_as_str if intermediate_dtype == pl.Utf8 else json_loads
    mapped = series.map_elements(loader, return_dtype=intermediate_dtype)  # pyright: ignore -- polars backcompat
    return mapped.cast(dtype)
|
|
168
229
|
|
|
169
230
|
|
|
@@ -174,19 +235,33 @@ def pl_duration_to_iso_string(expr: pl.Expr) -> pl.Expr:
|
|
|
174
235
|
except ImportError:
|
|
175
236
|
raise missing_dependency_exception("chalkpy[runtime]")
|
|
176
237
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
238
|
+
try:
|
|
239
|
+
return pl.format(
|
|
240
|
+
"{}P{}DT{}H{}M{}.{}S",
|
|
241
|
+
pl.when(expr.dt.microseconds() < 0) # pyright: ignore -- polars backcompat
|
|
242
|
+
.then(pl.lit("-"))
|
|
243
|
+
.otherwise(pl.lit("")), # pyright: ignore -- polars backcompat
|
|
244
|
+
expr.dt.days().abs().cast(pl.Utf8), # pyright: ignore -- polars backcompat
|
|
245
|
+
(expr.dt.hours().abs() % 24).cast(pl.Utf8), # pyright: ignore -- polars backcompat
|
|
246
|
+
(expr.dt.minutes().abs() % 60).cast(pl.Utf8), # pyright: ignore -- polars backcompat
|
|
247
|
+
(expr.dt.seconds().abs() % 60).cast(pl.Utf8), # pyright: ignore -- polars backcompat
|
|
248
|
+
(expr.dt.microseconds().abs() % 1_000_000) # pyright: ignore -- polars backcompat
|
|
249
|
+
.cast(pl.Utf8)
|
|
250
|
+
.str.pad_start(6, "0") # pyright: ignore -- polars backcompat
|
|
251
|
+
if is_new_polars
|
|
252
|
+
else (expr.dt.microseconds().abs() % 1_000_000) # pyright: ignore -- polars backcompat
|
|
253
|
+
.cast(pl.Utf8)
|
|
254
|
+
.str.rjust(6, "0"), # pyright: ignore -- polars backcompat
|
|
255
|
+
)
|
|
256
|
+
except AttributeError:
|
|
257
|
+
return (
|
|
258
|
+
pl.format("{}P{}DT{}H{}M{}.{}S", expr.dt.total_microseconds().abs() % 1_000_000)
|
|
259
|
+
.cast(pl.Utf8)
|
|
260
|
+
.str.pad_start(
|
|
261
|
+
6,
|
|
262
|
+
"0",
|
|
263
|
+
)
|
|
264
|
+
)
|
|
190
265
|
|
|
191
266
|
|
|
192
267
|
def pl_json_encode(expr: pl.Expr, dtype: pl.PolarsDataType):
|
|
@@ -374,7 +449,7 @@ def _json_encode_inner(expr: pl.Expr, dtype: pl.PolarsDataType) -> pl.Expr:
|
|
|
374
449
|
_backup_json_encode, return_dtype=pl.Utf8
|
|
375
450
|
)
|
|
376
451
|
else:
|
|
377
|
-
return expr.apply(_backup_json_encode, return_dtype=pl.Utf8)
|
|
452
|
+
return expr.apply(_backup_json_encode, return_dtype=pl.Utf8) # pyright: ignore -- polars backcompat
|
|
378
453
|
expr = expr.fill_null([])
|
|
379
454
|
lists_with_extra_none = (
|
|
380
455
|
expr.list.concat(pl.lit(None)) # pyright: ignore -- back compat
|
|
@@ -469,3 +544,81 @@ def recursively_has_struct(dtype: pa.DataType) -> bool:
|
|
|
469
544
|
assert isinstance(dtype, pa.MapType)
|
|
470
545
|
return recursively_has_struct(dtype.key_type) or recursively_has_struct(dtype.item_type)
|
|
471
546
|
return False
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def apply_compat(
    expr: "pl.Expr",
    function: Any,
    return_dtype: "pl.PolarsDataType | None" = None,
    **kwargs: Any,
) -> "pl.Expr":
    """
    Apply a custom element-wise function to an expression in a version-compatible way.

    In Polars >= 0.19, expr.apply() was deprecated in favor of expr.map_elements().
    This helper calls ``map_elements()`` when the object provides it and falls back
    to ``apply()`` otherwise.

    Args:
        expr: The Polars expression to apply the function to
        function: The function to apply to each element
        return_dtype: The return data type for the expression (optional); it is only
            forwarded when not ``None``
        **kwargs: Additional keyword arguments to pass to the underlying method

    Returns:
        A Polars expression with the function applied

    Raises:
        AttributeError: If ``expr`` has neither ``map_elements`` nor ``apply``. Any
            AttributeError raised while the mapping itself runs is propagated
            unchanged rather than triggering the ``apply()`` fallback.

    Example:
        >>> import polars as pl
        >>> df = pl.DataFrame({"a": [1, 2, 3]})
        >>> df.select(apply_compat(pl.col("a"), lambda x: x * 2, return_dtype=pl.Int64))
    """
    call_kwargs = dict(kwargs)
    if return_dtype is not None:
        call_kwargs["return_dtype"] = return_dtype

    # Resolve the method *before* calling it. Wrapping the call itself in
    # ``try/except AttributeError`` would misread an AttributeError raised while
    # mapping (e.g. from inside ``function`` on an eager Series) as "old polars".
    # It would then re-run the work via ``apply()`` or mask the real error.
    mapper = getattr(expr, "map_elements", None)
    if mapper is None:
        # Older API: apply()
        mapper = expr.apply  # type: ignore
    return mapper(function, **call_kwargs)  # type: ignore
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@overload
def str_json_decode_compat(expr: "pl.Expr", dtype: "pl.PolarsDataType") -> "pl.Expr":
    ...


@overload
def str_json_decode_compat(expr: "pl.Series", dtype: "pl.PolarsDataType") -> "pl.Series":
    ...


def str_json_decode_compat(expr: "pl.Expr | pl.Series", dtype: "pl.PolarsDataType") -> "pl.Expr | pl.Series":
    """
    Parse/decode JSON strings in a version-compatible way.

    Newer Polars releases expose ``str.json_decode()``; older ones only provide
    ``str.json_extract()``. This function uses ``json_decode`` when the string
    namespace has it and falls back to ``json_extract`` otherwise.

    Args:
        expr: The Polars expression (or Series) containing JSON strings to parse
        dtype: The Polars data type to extract to

    Returns:
        A Polars expression (or Series, when given a Series) that parses the JSON strings

    Raises:
        AttributeError: If neither decoder exists. Any AttributeError raised while
            decoding runs is propagated unchanged instead of triggering the
            ``json_extract()`` fallback.
    """
    string_namespace = expr.str
    # Look the method up first so an AttributeError raised during an eager
    # (Series) decode is not mistaken for "this polars version lacks json_decode".
    decoder = getattr(string_namespace, "json_decode", None)
    if decoder is None:
        # Older API: str.json_extract()
        decoder = string_namespace.json_extract  # type: ignore
    return decoder(dtype=dtype)  # type: ignore
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def schema_compat(df: "pl.DataFrame | pl.LazyFrame"):
    """Return the schema of ``df`` across polars versions.

    A ``pl.LazyFrame`` goes through ``collect_schema()`` when the module-level flag
    ``polars_lazy_frame_collect_schema`` is truthy. Every other case (an eager
    DataFrame, or a polars build where the flag is falsy) reads ``df.schema``.
    NOTE(review): presumably the flag records whether ``LazyFrame.collect_schema``
    exists in the installed polars. Confirm where it is defined.
    """
    wants_collected_schema = polars_lazy_frame_collect_schema and isinstance(df, pl.LazyFrame)
    return df.collect_schema() if wants_collected_schema else df.schema
|
|
@@ -2,12 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
from inspect import isclass
|
|
5
|
-
from typing import Any
|
|
5
|
+
from typing import Any, TypeGuard
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
from packaging import version
|
|
9
9
|
from pydantic import BaseModel
|
|
10
|
-
from typing_extensions import TypeGuard
|
|
11
10
|
|
|
12
11
|
try:
|
|
13
12
|
from pydantic.v1 import BaseModel as V1BaseModel
|