dataproc-spark-connect 0.8.3__py2.py3-none-any.whl → 1.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataproc_spark_connect-1.0.0.dist-info/METADATA +200 -0
- dataproc_spark_connect-1.0.0.dist-info/RECORD +13 -0
- google/cloud/dataproc_spark_connect/client/core.py +5 -3
- google/cloud/dataproc_spark_connect/environment.py +101 -0
- google/cloud/dataproc_spark_connect/exceptions.py +1 -1
- google/cloud/dataproc_spark_connect/session.py +644 -76
- dataproc_spark_connect-0.8.3.dist-info/METADATA +0 -105
- dataproc_spark_connect-0.8.3.dist-info/RECORD +0 -12
- {dataproc_spark_connect-0.8.3.dist-info → dataproc_spark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {dataproc_spark_connect-0.8.3.dist-info → dataproc_spark_connect-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {dataproc_spark_connect-0.8.3.dist-info → dataproc_spark_connect-1.0.0.dist-info}/top_level.txt +0 -0
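The bulk of the change is in `session.py`, which adds a named-session workflow and several new `Builder` configuration methods. Below is a hedged usage sketch of the 1.0.0 surface as it appears in the diff that follows; the import path, project ID, session name, and label values are illustrative assumptions, not taken from this diff:

```python
import datetime

# Import path assumed from the package name; not shown in this diff.
from google.cloud.dataproc_spark_connect import DataprocSparkSession

spark = (
    DataprocSparkSession.builder
    .projectId("my-project")                   # hypothetical project; GOOGLE_CLOUD_PROJECT is also read
    .dataprocSessionId("nightly-etl")          # new in 1.0.0: custom (named) session ID
    .runtimeVersion("3.0")                     # 3.0 is now the default and minimum runtime
    .ttl(datetime.timedelta(hours=4))          # new TTL helper, delegates to ttlSeconds()
    .idleTtl(datetime.timedelta(minutes=30))   # new idle-TTL helper
    .label("team", "data-eng")                 # user labels; system labels are filtered out
    .getOrCreate()
)

spark.stop()  # named sessions are not terminated server-side by default; see stop(terminate=...)
```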
--- google/cloud/dataproc_spark_connect/session.py (0.8.3)
+++ google/cloud/dataproc_spark_connect/session.py (1.0.0)
@@ -14,6 +14,7 @@
 
 import atexit
 import datetime
+import functools
 import json
 import logging
 import os
@@ -24,8 +25,9 @@ import threading
 import time
 import uuid
 import tqdm
+from packaging import version
 from types import MethodType
-from typing import Any, cast, ClassVar, Dict, Optional, Union
+from typing import Any, cast, ClassVar, Dict, Iterable, Optional, Union
 
 from google.api_core import retry
 from google.api_core.client_options import ClientOptions
@@ -43,12 +45,14 @@ from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
 from google.cloud.dataproc_v1 import (
     AuthenticationConfig,
     CreateSessionRequest,
+    DeleteSessionRequest,
     GetSessionRequest,
     Session,
     SessionControllerClient,
     TerminateSessionRequest,
 )
 from google.cloud.dataproc_v1.types import sessions
+from google.cloud.dataproc_spark_connect import environment
 from pyspark.sql.connect.session import SparkSession
 from pyspark.sql.utils import to_str
 
@@ -56,6 +60,16 @@ from pyspark.sql.utils import to_str
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# System labels that should not be overridden by user
+SYSTEM_LABELS = {
+    "dataproc-session-client",
+    "goog-colab-notebook-id",
+}
+
+_DATAPROC_SESSIONS_BASE_URL = (
+    "https://console.cloud.google.com/dataproc/interactive"
+)
+
 
 def _is_valid_label_value(value: str) -> bool:
     """
@@ -77,6 +91,22 @@ def _is_valid_label_value(value: str) -> bool:
     return bool(re.match(pattern, value))
 
 
+def _is_valid_session_id(session_id: str) -> bool:
+    """
+    Validates if a string complies with Google Cloud session ID format.
+    - Must be 4-63 characters
+    - Only lowercase letters, numbers, and dashes are allowed
+    - Must start with a lowercase letter
+    - Cannot end with a dash
+    """
+    if not session_id:
+        return False
+
+    # The pattern is sufficient for validation and already enforces length constraints.
+    pattern = r"^[a-z][a-z0-9-]{2,61}[a-z0-9]$"
+    return bool(re.match(pattern, session_id))
+
+
 class DataprocSparkSession(SparkSession):
     """The entry point to programming Spark with the Dataset and DataFrame API.
 
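The new `_is_valid_session_id` helper relies entirely on the regular expression above. A minimal standalone sketch exercising the same pattern, using only `re`:

```python
import re

# Same pattern as _is_valid_session_id above: 4-63 chars, lowercase start,
# lowercase letters/digits/dashes only, no trailing dash.
pattern = r"^[a-z][a-z0-9-]{2,61}[a-z0-9]$"

assert re.match(pattern, "my-session-01")            # valid
assert not re.match(pattern, "abc")                  # too short (minimum 4 characters)
assert not re.match(pattern, "1-starts-with-digit")  # must start with a lowercase letter
assert not re.match(pattern, "ends-with-dash-")      # cannot end with a dash
```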
@@ -96,13 +126,16 @@ class DataprocSparkSession(SparkSession):
     ... ) # doctest: +SKIP
     """
 
-    _DEFAULT_RUNTIME_VERSION = "
+    _DEFAULT_RUNTIME_VERSION = "3.0"
+    _MIN_RUNTIME_VERSION = "3.0"
 
     _active_s8s_session_uuid: ClassVar[Optional[str]] = None
     _project_id = None
     _region = None
     _client_options = None
     _active_s8s_session_id: ClassVar[Optional[str]] = None
+    _active_session_uses_custom_id: ClassVar[bool] = False
+    _execution_progress_bar = dict()
 
     class Builder(SparkSession.Builder):
 
@@ -110,6 +143,7 @@ class DataprocSparkSession(SparkSession):
             self._options: Dict[str, Any] = {}
             self._channel_builder: Optional[DataprocChannelBuilder] = None
             self._dataproc_config: Optional[Session] = None
+            self._custom_session_id: Optional[str] = None
             self._project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
             self._region = os.getenv("GOOGLE_CLOUD_REGION")
             self._client_options = ClientOptions(
@@ -118,6 +152,18 @@ class DataprocSparkSession(SparkSession):
                     f"{self._region}-dataproc.googleapis.com",
                 )
             )
+            self._session_controller_client: Optional[
+                SessionControllerClient
+            ] = None
+
+        @property
+        def session_controller_client(self) -> SessionControllerClient:
+            """Get or create a SessionControllerClient instance."""
+            if self._session_controller_client is None:
+                self._session_controller_client = SessionControllerClient(
+                    client_options=self._client_options
+                )
+            return self._session_controller_client
 
         def projectId(self, project_id):
             self._project_id = project_id
@@ -131,12 +177,106 @@ class DataprocSparkSession(SparkSession):
             )
             return self
 
+        def dataprocSessionId(self, session_id: str):
+            """
+            Set a custom session ID for creating or reusing sessions.
+
+            The session ID must:
+            - Be 4-63 characters long
+            - Start with a lowercase letter
+            - Contain only lowercase letters, numbers, and hyphens
+            - Not end with a hyphen
+
+            Args:
+                session_id: The custom session ID to use
+
+            Returns:
+                This Builder instance for method chaining
+
+            Raises:
+                ValueError: If the session ID format is invalid
+            """
+            if not _is_valid_session_id(session_id):
+                raise ValueError(
+                    f"Invalid session ID: '{session_id}'. "
+                    "Session ID must be 4-63 characters, start with a lowercase letter, "
+                    "contain only lowercase letters, numbers, and hyphens, "
+                    "and not end with a hyphen."
+                )
+            self._custom_session_id = session_id
+            return self
+
         def dataprocSessionConfig(self, dataproc_config: Session):
+            self._dataproc_config = dataproc_config
+            for k, v in dataproc_config.runtime_config.properties.items():
+                self._options[cast(str, k)] = to_str(v)
+            return self
+
+        @property
+        def dataproc_config(self):
             with self._lock:
-                self._dataproc_config =
-
-
-
+                self._dataproc_config = self._dataproc_config or Session()
+            return self._dataproc_config
+
+        def runtimeVersion(self, version: str):
+            self.dataproc_config.runtime_config.version = version
+            return self
+
+        def serviceAccount(self, account: str):
+            self.dataproc_config.environment_config.execution_config.service_account = (
+                account
+            )
+            return self
+
+        def subnetwork(self, subnet: str):
+            self.dataproc_config.environment_config.execution_config.subnetwork_uri = (
+                subnet
+            )
+            return self
+
+        def ttl(self, duration: datetime.timedelta):
+            """Set the time-to-live (TTL) for the session using a timedelta object."""
+            return self.ttlSeconds(int(duration.total_seconds()))
+
+        def ttlSeconds(self, seconds: int):
+            """Set the time-to-live (TTL) for the session in seconds."""
+            self.dataproc_config.environment_config.execution_config.ttl = {
+                "seconds": seconds
+            }
+            return self
+
+        def idleTtl(self, duration: datetime.timedelta):
+            """Set the idle time-to-live (idle TTL) for the session using a timedelta object."""
+            return self.idleTtlSeconds(int(duration.total_seconds()))
+
+        def idleTtlSeconds(self, seconds: int):
+            """Set the idle time-to-live (idle TTL) for the session in seconds."""
+            self.dataproc_config.environment_config.execution_config.idle_ttl = {
+                "seconds": seconds
+            }
+            return self
+
+        def sessionTemplate(self, template: str):
+            self.dataproc_config.session_template = template
+            return self
+
+        def label(self, key: str, value: str):
+            """Add a single label to the session."""
+            return self.labels({key: value})
+
+        def labels(self, labels: Dict[str, str]):
+            # Filter out system labels and warn user
+            filtered_labels = {}
+            for key, value in labels.items():
+                if key in SYSTEM_LABELS:
+                    logger.warning(
+                        f"Label '{key}' is a system label and cannot be overridden by user. Ignoring."
+                    )
+                else:
+                    filtered_labels[key] = value
+
+            self.dataproc_config.labels.update(filtered_labels)
+            return self
 
         def remote(self, url: Optional[str] = None) -> "SparkSession.Builder":
             if url:
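The new TTL and label setters above are thin wrappers over the session config. A short sketch of how they compose, assuming the import path used earlier; nothing here contacts the service:

```python
import datetime

from google.cloud.dataproc_spark_connect import DataprocSparkSession  # assumed import path

builder = DataprocSparkSession.builder

# ttl() converts the timedelta and delegates to ttlSeconds(), so these two
# calls configure the same 2-hour server-side TTL.
builder.ttl(datetime.timedelta(hours=2))
builder.ttlSeconds(7200)

# labels() drops keys listed in SYSTEM_LABELS and logs a warning for them.
builder.labels({"team": "data-eng", "dataproc-session-client": "ignored"})
```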
@@ -175,7 +315,11 @@ class DataprocSparkSession(SparkSession):
             assert self._channel_builder is not None
             session = DataprocSparkSession(connection=self._channel_builder)
 
+            # Register handler for Cell Execution Progress bar
+            session._register_progress_execution_handler()
+
             DataprocSparkSession._set_default_and_active_session(session)
+
             return session
 
         def __create(self) -> "DataprocSparkSession":
@@ -190,7 +334,16 @@ class DataprocSparkSession(SparkSession):
 
             dataproc_config: Session = self._get_dataproc_config()
 
-
+            # Check runtime version compatibility before creating session
+            self._check_runtime_compatibility(dataproc_config)
+
+            # Use custom session ID if provided, otherwise generate one
+            session_id = (
+                self._custom_session_id
+                if self._custom_session_id
+                else self.generate_dataproc_session_id()
+            )
+
             dataproc_config.name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
             logger.debug(
                 f"Dataproc Session configuration:\n{dataproc_config}"
@@ -205,6 +358,10 @@ class DataprocSparkSession(SparkSession):
 
             logger.debug("Creating Dataproc Session")
             DataprocSparkSession._active_s8s_session_id = session_id
+            # Track whether this session uses a custom ID (unmanaged) or auto-generated ID (managed)
+            DataprocSparkSession._active_session_uses_custom_id = (
+                self._custom_session_id is not None
+            )
             s8s_creation_start_time = time.time()
 
             stop_create_session_pbar_event = threading.Event()
@@ -258,8 +415,7 @@ class DataprocSparkSession(SparkSession):
                     client_options=self._client_options
                 ).create_session(session_request)
                 self._display_session_link_on_creation(session_id)
-
-                # self._display_view_session_details_button(session_id)
+                self._display_view_session_details_button(session_id)
                 create_session_pbar_thread.start()
                 session_response: Session = operation.result(
                     polling=retry.Retry(
@@ -296,6 +452,7 @@ class DataprocSparkSession(SparkSession):
                 if create_session_pbar_thread.is_alive():
                     create_session_pbar_thread.join()
                 DataprocSparkSession._active_s8s_session_id = None
+                DataprocSparkSession._active_session_uses_custom_id = False
                 raise DataprocSparkConnectException(
                     f"Error while creating Dataproc Session: {e.message}"
                 )
@@ -304,6 +461,7 @@ class DataprocSparkSession(SparkSession):
                 if create_session_pbar_thread.is_alive():
                     create_session_pbar_thread.join()
                 DataprocSparkSession._active_s8s_session_id = None
+                DataprocSparkSession._active_session_uses_custom_id = False
                 raise RuntimeError(
                     f"Error while creating Dataproc Session"
                 ) from e
@@ -317,16 +475,43 @@ class DataprocSparkSession(SparkSession):
                 session_response, dataproc_config.name
             )
 
+        def _wait_for_session_available(
+            self, session_name: str, timeout: int = 300
+        ) -> Session:
+            start_time = time.time()
+            while time.time() - start_time < timeout:
+                try:
+                    session = self.session_controller_client.get_session(
+                        name=session_name
+                    )
+                    if "Spark Connect Server" in session.runtime_info.endpoints:
+                        return session
+                    time.sleep(5)
+                except Exception as e:
+                    logger.warning(
+                        f"Error while polling for Spark Connect endpoint: {e}"
+                    )
+                    time.sleep(5)
+            raise RuntimeError(
+                f"Spark Connect endpoint not available for session {session_name} after {timeout} seconds."
+            )
+
         def _display_session_link_on_creation(self, session_id):
-            session_url = f"
+            session_url = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{session_id}?project={self._project_id}"
             plain_message = f"Creating Dataproc Session: {session_url}"
-
+            if environment.is_colab_enterprise():
+                html_element = f"""
                 <div>
                 <p>Creating Dataproc Spark Session<p>
-                <p><a href="{session_url}">Dataproc Session</a></p>
                 </div>
-
-
+                """
+            else:
+                html_element = f"""
+                <div>
+                <p>Creating Dataproc Spark Session<p>
+                <p><a href="{session_url}">Dataproc Session</a></p>
+                </div>
+                """
             self._output_element_or_message(plain_message, html_element)
 
         def _print_session_created_message(self):
@@ -345,16 +530,19 @@ class DataprocSparkSession(SparkSession):
            :param html_element: HTML element to display for interactive IPython
                environment
            """
+            # Don't print any output (Rich or Plain) for non-interactive
+            if not environment.is_interactive():
+                return
+
+            if environment.is_interactive_terminal():
+                print(plain_message)
+                return
+
             try:
                 from IPython.display import display, HTML
-                from IPython.core.interactiveshell import InteractiveShell
 
-                if not InteractiveShell.initialized():
-                    raise DataprocSparkConnectException(
-                        "Not in an Interactive IPython Environment"
-                    )
                 display(HTML(html_element))
-            except
+            except ImportError:
                 print(plain_message)
 
         def _get_exiting_active_session(
@@ -375,11 +563,13 @@ class DataprocSparkSession(SparkSession):
 
             if session_response is not None:
                 print(
-                    f"Using existing Dataproc Session (configuration changes may not be applied):
+                    f"Using existing Dataproc Session (configuration changes may not be applied): {_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{s8s_session_id}?project={self._project_id}"
                 )
-
-                # self._display_view_session_details_button(s8s_session_id)
+                self._display_view_session_details_button(s8s_session_id)
                 if session is None:
+                    session_response = self._wait_for_session_available(
+                        session_name
+                    )
                     session = self.__create_spark_connect_session_from_s8s(
                         session_response, session_name
                     )
@@ -395,17 +585,59 @@ class DataprocSparkSession(SparkSession):
 
         def getOrCreate(self) -> "DataprocSparkSession":
             with DataprocSparkSession._lock:
+                if environment.is_dataproc_batch():
+                    # For Dataproc batch workloads, connect to the already initialized local SparkSession
+                    from pyspark.sql import SparkSession as PySparkSQLSession
+
+                    session = PySparkSQLSession.builder.getOrCreate()
+                    return session  # type: ignore
+
+                if self._project_id is None:
+                    raise DataprocSparkConnectException(
+                        f"Error while creating Dataproc Session: project ID is not set"
+                    )
+
+                if self._region is None:
+                    raise DataprocSparkConnectException(
+                        f"Error while creating Dataproc Session: location is not set"
+                    )
+
+                # Handle custom session ID by setting it early and letting existing logic handle it
+                if self._custom_session_id:
+                    self._handle_custom_session_id()
+
                 session = self._get_exiting_active_session()
                 if session is None:
                     session = self.__create()
+
+                # Register this session as the instantiated SparkSession for compatibility
+                # with tools and libraries that expect SparkSession._instantiatedSession
+                from pyspark.sql import SparkSession as PySparkSQLSession
+
+                PySparkSQLSession._instantiatedSession = session
+
                 return session
 
+        def _handle_custom_session_id(self):
+            """Handle custom session ID by checking if it exists and setting _active_s8s_session_id."""
+            session_response = self._get_session_by_id(self._custom_session_id)
+            if session_response is not None:
+                # Found an active session with the custom ID, set it as the active session
+                DataprocSparkSession._active_s8s_session_id = (
+                    self._custom_session_id
+                )
+                # Mark that this session uses a custom ID
+                DataprocSparkSession._active_session_uses_custom_id = True
+            else:
+                # No existing session found, clear any existing active session ID
+                # so we'll create a new one with the custom ID
+                DataprocSparkSession._active_s8s_session_id = None
+
         def _get_dataproc_config(self):
-
-
-
-
-            dataproc_config.runtime_config.properties[k] = v
+            # Use the property to ensure we always have a config
+            dataproc_config = self.dataproc_config
+            for k, v in self._options.items():
+                dataproc_config.runtime_config.properties[k] = v
             dataproc_config.spark_connect_session = (
                 sessions.SparkConnectConfig()
             )
@@ -413,20 +645,38 @@ class DataprocSparkSession(SparkSession):
                 dataproc_config.runtime_config.version = (
                     DataprocSparkSession._DEFAULT_RUNTIME_VERSION
                 )
+
+            # Check for Python version mismatch with runtime for UDF compatibility
+            self._check_python_version_compatibility(
+                dataproc_config.runtime_config.version
+            )
+
+            # Use local variable to improve readability of deeply nested attribute access
+            exec_config = dataproc_config.environment_config.execution_config
+
+            # Set service account from environment if not already set
             if (
-                not
-                and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
-            ):
-                dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
-                    os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
-                ]
-            if (
-                not dataproc_config.environment_config.execution_config.service_account
+                not exec_config.service_account
                 and "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT" in os.environ
             ):
-
+                exec_config.service_account = os.getenv(
                     "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"
                 )
+
+            # Auto-set authentication type to SERVICE_ACCOUNT when service account is provided
+            if exec_config.service_account:
+                # When service account is provided, explicitly set auth type to SERVICE_ACCOUNT
+                exec_config.authentication_config.user_workload_authentication_type = (
+                    AuthenticationConfig.AuthenticationType.SERVICE_ACCOUNT
+                )
+            elif (
+                not exec_config.authentication_config.user_workload_authentication_type
+                and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
+            ):
+                # Only set auth type from environment if no service account is present
+                exec_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
+                    os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
+                ]
             if (
                 not dataproc_config.environment_config.execution_config.subnetwork_uri
                 and "DATAPROC_SPARK_CONNECT_SUBNET" in os.environ
@@ -452,6 +702,10 @@ class DataprocSparkSession(SparkSession):
                     os.getenv("DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS")
                 )
             }
+            client_environment = environment.get_client_environment_label()
+            dataproc_config.labels["dataproc-session-client"] = (
+                client_environment
+            )
             if "COLAB_NOTEBOOK_ID" in os.environ:
                 colab_notebook_name = os.environ["COLAB_NOTEBOOK_ID"]
                 # Extract the last part of the path, which is the ID
@@ -466,37 +720,102 @@ class DataprocSparkSession(SparkSession):
                        f"Only lowercase letters, numbers, and dashes are allowed. "
                        f"The value must start with lowercase letter or number and end with a lowercase letter or number. "
                        f"Maximum length is 63 characters. "
-                        f"
+                        f"Ignoring notebook ID label."
                     )
             default_datasource = os.getenv(
                 "DATAPROC_SPARK_CONNECT_DEFAULT_DATASOURCE"
             )
-
-
-
-
-
-            bq_datasource_properties = {
-                "spark.datasource.bigquery.viewsEnabled": "true",
-                "spark.datasource.bigquery.writeMethod": "direct",
+            match default_datasource:
+                case "bigquery":
+                    # Merge default configs with existing properties,
+                    # user configs take precedence
+                    for k, v in {
                         "spark.sql.catalog.spark_catalog": "com.google.cloud.spark.bigquery.BigQuerySparkSessionCatalog",
-                        "spark.sql.legacy.createHiveTableByDefault": "false",
                         "spark.sql.sources.default": "bigquery",
-            }
-            # Merge default configs with existing properties, user configs take precedence
-            for k, v in bq_datasource_properties.items():
+                    }.items():
                         if k not in dataproc_config.runtime_config.properties:
                             dataproc_config.runtime_config.properties[k] = v
-
-
-
-
-
+                case _:
+                    if default_datasource:
+                        logger.warning(
+                            f"DATAPROC_SPARK_CONNECT_DEFAULT_DATASOURCE is set to an invalid value:"
+                            f" {default_datasource}. Supported value is 'bigquery'."
+                        )
+
             return dataproc_config
 
+        def _check_python_version_compatibility(self, runtime_version):
+            """Check if client Python version matches server Python version for UDF compatibility."""
+            import sys
+            import warnings
+
+            # Runtime version to server Python version mapping
+            RUNTIME_PYTHON_MAP = {
+                "3.0": (3, 12),
+            }
+
+            client_python = sys.version_info[:2]  # (major, minor)
+
+            if runtime_version in RUNTIME_PYTHON_MAP:
+                server_python = RUNTIME_PYTHON_MAP[runtime_version]
+
+                if client_python != server_python:
+                    warnings.warn(
+                        f"Python version mismatch detected: Client is using Python {client_python[0]}.{client_python[1]}, "
+                        f"but Dataproc runtime {runtime_version} uses Python {server_python[0]}.{server_python[1]}. "
+                        f"This mismatch may cause issues with Python UDF (User Defined Function) compatibility. "
+                        f"Consider using Python {server_python[0]}.{server_python[1]} for optimal UDF execution.",
+                        stacklevel=3,
+                    )
+
+        def _check_runtime_compatibility(self, dataproc_config):
+            """Check if runtime version 3.0 client is compatible with older runtime versions.
+
+            Runtime version 3.0 clients do not support older runtime versions (pre-3.0).
+            There is no backward or forward compatibility between different runtime versions.
+
+            Args:
+                dataproc_config: The Session configuration containing runtime version
+
+            Raises:
+                DataprocSparkConnectException: If server is using pre-3.0 runtime version
+            """
+            runtime_version = dataproc_config.runtime_config.version
+
+            if not runtime_version:
+                return
+
+            logger.debug(f"Detected server runtime version: {runtime_version}")
+
+            # Parse runtime version to check if it's below minimum supported version
+            try:
+                server_version = version.parse(runtime_version)
+                min_version = version.parse(
+                    DataprocSparkSession._MIN_RUNTIME_VERSION
+                )
+
+                if server_version < min_version:
+                    raise DataprocSparkConnectException(
+                        f"Specified {runtime_version} Dataproc Runtime version is not supported, "
+                        f"use {DataprocSparkSession._MIN_RUNTIME_VERSION} version or higher."
+                    )
+            except version.InvalidVersion:
+                # If we can't parse the version, log a warning but continue
+                logger.warning(
+                    f"Could not parse runtime version: {runtime_version}"
+                )
+
         def _display_view_session_details_button(self, session_id):
+            # Display button is only supported in colab enterprise
+            if not environment.is_colab_enterprise():
+                return
+
+            # Skip button display for colab enterprise IPython terminals
+            if environment.is_interactive_terminal():
+                return
+
             try:
-                session_url = f"
+                session_url = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{session_id}?project={self._project_id}"
                 from IPython.core.interactiveshell import InteractiveShell
 
                 if not InteractiveShell.initialized():
@@ -510,6 +829,90 @@ class DataprocSparkSession(SparkSession):
             except ImportError as e:
                 logger.debug(f"Import error: {e}")
 
+        def _get_session_by_id(self, session_id: str) -> Optional[Session]:
+            """
+            Get existing session by ID.
+
+            Returns:
+                Session if ACTIVE/CREATING, None if not found or not usable
+            """
+            session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
+
+            try:
+                get_request = GetSessionRequest(name=session_name)
+                session = self.session_controller_client.get_session(
+                    get_request
+                )
+
+                logger.debug(
+                    f"Found existing session {session_id} in state: {session.state}"
+                )
+
+                if session.state in [
+                    Session.State.ACTIVE,
+                    Session.State.CREATING,
+                ]:
+                    # Reuse the active session
+                    logger.info(f"Reusing existing session: {session_id}")
+                    return session
+                else:
+                    # Session exists but is not usable (terminated/failed/terminating)
+                    logger.info(
+                        f"Session {session_id} in {session.state.name} state, cannot reuse"
+                    )
+                    return None
+
+            except NotFound:
+                # Session doesn't exist, can create new one
+                logger.debug(
+                    f"Session {session_id} not found, can create new one"
+                )
+                return None
+            except Exception as e:
+                logger.error(f"Error checking session {session_id}: {e}")
+                return None
+
+        def _delete_session(self, session_name: str):
+            """Delete a session to free up the session ID for reuse."""
+            try:
+                delete_request = DeleteSessionRequest(name=session_name)
+                self.session_controller_client.delete_session(delete_request)
+                logger.debug(f"Deleted session: {session_name}")
+            except NotFound:
+                logger.debug(f"Session already deleted: {session_name}")
+
+        def _wait_for_termination(self, session_name: str, timeout: int = 180):
+            """Wait for a session to finish terminating."""
+            start_time = time.time()
+
+            while time.time() - start_time < timeout:
+                try:
+                    get_request = GetSessionRequest(name=session_name)
+                    session = self.session_controller_client.get_session(
+                        get_request
+                    )
+
+                    if session.state in [
+                        Session.State.TERMINATED,
+                        Session.State.FAILED,
+                    ]:
+                        return
+                    elif session.state != Session.State.TERMINATING:
+                        # Session is in unexpected state
+                        logger.warning(
+                            f"Session {session_name} in unexpected state while waiting for termination: {session.state}"
+                        )
+                        return
+
+                    time.sleep(2)
+                except NotFound:
+                    # Session was deleted
+                    return
+
+            logger.warning(
+                f"Timeout waiting for session {session_name} to terminate"
+            )
+
         @staticmethod
         def generate_dataproc_session_id():
             timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -583,16 +986,111 @@ class DataprocSparkSession(SparkSession):
             execute_and_fetch_as_iterator_wrapped_method, self.client
         )
 
+        # Patching clearProgressHandlers method to not remove Dataproc Progress Handler
+        clearProgressHandlers_base_method = self.clearProgressHandlers
+
+        def clearProgressHandlers_wrapper_method(_, *args, **kwargs):
+            clearProgressHandlers_base_method(*args, **kwargs)
+
+            self._register_progress_execution_handler()
+
+        self.clearProgressHandlers = MethodType(
+            clearProgressHandlers_wrapper_method, self
+        )
+
+    @staticmethod
+    @functools.lru_cache(maxsize=1)
+    def get_tqdm_bar():
+        """
+        Return a tqdm implementation that works in the current environment.
+
+        - Uses CLI tqdm for interactive terminals.
+        - Uses the notebook tqdm if available, otherwise falls back to CLI tqdm.
+        """
+        from tqdm import tqdm as cli_tqdm
+
+        if environment.is_interactive_terminal():
+            return cli_tqdm
+
+        try:
+            import ipywidgets
+            from tqdm.notebook import tqdm as notebook_tqdm
+
+            return notebook_tqdm
+        except ImportError:
+            return cli_tqdm
+
+    def _register_progress_execution_handler(self):
+        from pyspark.sql.connect.shell.progress import StageInfo
+
+        def handler(
+            stages: Optional[Iterable[StageInfo]],
+            inflight_tasks: int,
+            operation_id: Optional[str],
+            done: bool,
+        ):
+            if operation_id is None:
+                return
+
+            # Don't build / render progress bar for non-interactive (despite
+            # Ipython or non-IPython)
+            if not environment.is_interactive():
+                return
+
+            total_tasks = 0
+            completed_tasks = 0
+
+            for stage in stages or []:
+                total_tasks += stage.num_tasks
+                completed_tasks += stage.num_completed_tasks
+
+            # Don't show progress bar till we receive some tasks
+            if total_tasks == 0:
+                return
+
+            # Get correct tqdm (notebook or CLI)
+            tqdm_pbar = self.get_tqdm_bar()
+
+            # Use a lock to ensure only one thread can access and modify
+            # the shared dictionaries at a time.
+            with self._lock:
+                if operation_id in self._execution_progress_bar:
+                    pbar = self._execution_progress_bar[operation_id]
+                    if pbar.total != total_tasks:
+                        pbar.reset(
+                            total=total_tasks
+                        )  # This force resets the progress bar % too on next refresh
+                else:
+                    pbar = tqdm_pbar(
+                        total=total_tasks,
+                        leave=True,
+                        dynamic_ncols=True,
+                        bar_format="{l_bar}{bar} {n_fmt}/{total_fmt} Tasks",
+                    )
+                    self._execution_progress_bar[operation_id] = pbar
+
+                # To handle skipped or failed tasks.
+                # StageInfo proto doesn't have skipped and failed tasks information to process.
+                if done and completed_tasks < total_tasks:
+                    completed_tasks = total_tasks
+
+                pbar.n = completed_tasks
+                pbar.refresh()
+
+                if done:
+                    pbar.close()
+                    self._execution_progress_bar.pop(operation_id, None)
+
+        self.registerProgressHandler(handler)
+
     @staticmethod
     def _sql_lazy_transformation(req):
         # Select SQL command
-
-
-
-
-
-
-        return False
+        try:
+            query = req.plan.command.sql_command.input.sql.query
+            return "select" in query.strip().lower().split()
+        except AttributeError:
+            return False
 
     def _repr_html_(self) -> str:
         if not self._active_s8s_session_id:
@@ -600,7 +1098,7 @@ class DataprocSparkSession(SparkSession):
             <div>No Active Dataproc Session</div>
             """
 
-        s8s_session = f"
+        s8s_session = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{self._active_s8s_session_id}"
         ui = f"{s8s_session}/sparkApplications/applications"
         return f"""
         <div>
@@ -612,6 +1110,11 @@ class DataprocSparkSession(SparkSession):
         """
 
     def _display_operation_link(self, operation_id: str):
+        # Don't print per-operation Spark UI link for non-interactive (despite
+        # Ipython or non-IPython)
+        if not environment.is_interactive():
+            return
+
         assert all(
             [
                 operation_id is not None,
@@ -622,17 +1125,18 @@ class DataprocSparkSession(SparkSession):
         )
 
         url = (
-            f"
+            f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/"
             f"{self._active_s8s_session_id}/sparkApplications/application;"
             f"associatedSqlOperationId={operation_id}?project={self._project_id}"
         )
 
+        if environment.is_interactive_terminal():
+            print(f"Spark Query: {url}")
+            return
+
         try:
             from IPython.display import display, HTML
-            from IPython.core.interactiveshell import InteractiveShell
 
-            if not InteractiveShell.initialized():
-                return
             html_element = f"""
             <div>
             <p><a href="{url}">Spark Query</a> (Operation: {operation_id})</p>
@@ -690,7 +1194,7 @@ class DataprocSparkSession(SparkSession):
        This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
        an exception.
        Regarding pypi: Popular packages are already pre-installed in s8s runtime.
-        https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.
+        https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.3#python_libraries
        If there are conflicts/package doesn't exist, it throws an exception.
        """
        if sum([pypi, file, pyfile, archive]) > 1:
@@ -713,19 +1217,83 @@ class DataprocSparkSession(SparkSession):
     def _get_active_session_file_path():
         return os.getenv("DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH")
 
-    def stop(self) -> None:
+    def stop(self, terminate: Optional[bool] = None) -> None:
+        """
+        Stop the Spark session and optionally terminate the server-side session.
+
+        Parameters
+        ----------
+        terminate : bool, optional
+            Control server-side termination behavior.
+
+            - None (default): Auto-detect based on session type
+
+              - Managed sessions (auto-generated ID): terminate server
+              - Named sessions (custom ID): client-side cleanup only
+
+            - True: Always terminate the server-side session
+            - False: Never terminate the server-side session (client cleanup only)
+
+        Examples
+        --------
+        Auto-detect termination behavior (existing behavior):
+
+        >>> spark.stop()
+
+        Force terminate a named session:
+
+        >>> spark.stop(terminate=True)
+
+        Prevent termination of a managed session:
+
+        >>> spark.stop(terminate=False)
+        """
         with DataprocSparkSession._lock:
             if DataprocSparkSession._active_s8s_session_id is not None:
-
-
-
-
-
-
+                # Determine if we should terminate the server-side session
+                if terminate is None:
+                    # Auto-detect: managed sessions terminate, named sessions don't
+                    should_terminate = (
+                        not DataprocSparkSession._active_session_uses_custom_id
+                    )
+                else:
+                    should_terminate = terminate
+
+                if should_terminate:
+                    # Terminate the server-side session
+                    logger.debug(
+                        f"Terminating session {DataprocSparkSession._active_s8s_session_id}"
+                    )
+                    terminate_s8s_session(
+                        DataprocSparkSession._project_id,
+                        DataprocSparkSession._region,
+                        DataprocSparkSession._active_s8s_session_id,
+                        self._client_options,
+                    )
+                else:
+                    # Client-side cleanup only
+                    logger.debug(
+                        f"Stopping session {DataprocSparkSession._active_s8s_session_id} without termination"
+                    )
 
                 self._remove_stopped_session_from_file()
+
+                # Clean up SparkSession._instantiatedSession if it points to this session
+                try:
+                    from pyspark.sql import SparkSession as PySparkSQLSession
+
+                    if PySparkSQLSession._instantiatedSession is self:
+                        PySparkSQLSession._instantiatedSession = None
+                        logger.debug(
+                            "Cleared SparkSession._instantiatedSession reference"
+                        )
+                except (ImportError, AttributeError):
+                    # PySpark not available or _instantiatedSession doesn't exist
+                    pass
+
             DataprocSparkSession._active_s8s_session_uuid = None
             DataprocSparkSession._active_s8s_session_id = None
+            DataprocSparkSession._active_session_uses_custom_id = False
             DataprocSparkSession._project_id = None
             DataprocSparkSession._region = None
             DataprocSparkSession._client_options = None