dataproc-spark-connect 0.2.1__py2.py3-none-any.whl → 0.7.0__py2.py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only.
@@ -11,36 +11,37 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
+ import atexit
+ import datetime
  import json
  import logging
  import os
  import random
  import string
+ import threading
  import time
- import datetime
- from time import sleep
- from typing import Any, cast, ClassVar, Dict, Optional
+ import tqdm

  from google.api_core import retry
- from google.api_core.future.polling import POLLING_PREDICATE
  from google.api_core.client_options import ClientOptions
- from google.api_core.exceptions import FailedPrecondition, InvalidArgument, NotFound
- from google.cloud.dataproc_v1.types import sessions
-
+ from google.api_core.exceptions import Aborted, FailedPrecondition, InvalidArgument, NotFound, PermissionDenied
+ from google.api_core.future.polling import POLLING_PREDICATE
  from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
+ from google.cloud.dataproc_spark_connect.exceptions import DataprocSparkConnectException
+ from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
  from google.cloud.dataproc_v1 import (
+ AuthenticationConfig,
  CreateSessionRequest,
  GetSessionRequest,
  Session,
  SessionControllerClient,
- SessionTemplate,
  TerminateSessionRequest,
  )
- from google.protobuf import text_format
- from google.protobuf.text_format import ParseError
+ from google.cloud.dataproc_v1.types import sessions
  from pyspark.sql.connect.session import SparkSession
  from pyspark.sql.utils import to_str
-
+ from typing import Any, cast, ClassVar, Dict, Optional

  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -61,11 +62,13 @@ class DataprocSparkSession(SparkSession):
  >>> spark = (
  ... DataprocSparkSession.builder
  ... .appName("Word Count")
- ... .dataprocConfig(Session())
+ ... .dataprocSessionConfig(Session())
  ... .getOrCreate()
  ... ) # doctest: +SKIP
  """

+ _DEFAULT_RUNTIME_VERSION = "2.2"
+
  _active_s8s_session_uuid: ClassVar[Optional[str]] = None
  _project_id = None
  _region = None
@@ -74,7 +77,12 @@ class DataprocSparkSession(SparkSession):

  class Builder(SparkSession.Builder):

- _dataproc_runtime_spark_version = {"3.0": "3.5.1", "2.2": "3.5.0"}
+ _dataproc_runtime_to_spark_version = {
+ "1.2": "3.5",
+ "2.2": "3.5",
+ "2.3": "3.5",
+ "3.0": "4.0",
+ }

  _session_static_configs = [
  "spark.executor.cores",
@@ -90,10 +98,10 @@ class DataprocSparkSession(SparkSession):
  self._options: Dict[str, Any] = {}
  self._channel_builder: Optional[DataprocChannelBuilder] = None
  self._dataproc_config: Optional[Session] = None
- self._project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
- self._region = os.environ.get("GOOGLE_CLOUD_REGION")
+ self._project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
+ self._region = os.getenv("GOOGLE_CLOUD_REGION")
  self._client_options = ClientOptions(
- api_endpoint=os.environ.get(
+ api_endpoint=os.getenv(
  "GOOGLE_CLOUD_DATAPROC_API_ENDPOINT",
  f"{self._region}-dataproc.googleapis.com",
  )
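The constructor above now resolves its defaults with os.getenv: the project from GOOGLE_CLOUD_PROJECT, the region from GOOGLE_CLOUD_REGION, and the API endpoint from GOOGLE_CLOUD_DATAPROC_API_ENDPOINT (falling back to the regional endpoint). A minimal sketch of relying on those defaults; the values and the top-level import path are assumptions, not part of this diff:

    import os

    # Placeholder values; the builder reads them in __init__ (see the hunk above).
    os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"
    os.environ["GOOGLE_CLOUD_REGION"] = "us-central1"
    # Optional override; otherwise "<region>-dataproc.googleapis.com" is used.
    # os.environ["GOOGLE_CLOUD_DATAPROC_API_ENDPOINT"] = "us-central1-dataproc.googleapis.com"

    from google.cloud.dataproc_spark_connect import DataprocSparkSession  # import path assumed

    spark = DataprocSparkSession.builder.getOrCreate()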
@@ -112,15 +120,15 @@ class DataprocSparkSession(SparkSession):
  self._project_id = project_id
  return self

- def region(self, region):
- self._region = region
- self._client_options.api_endpoint = os.environ.get(
+ def location(self, location):
+ self._region = location
+ self._client_options.api_endpoint = os.getenv(
  "GOOGLE_CLOUD_DATAPROC_API_ENDPOINT",
  f"{self._region}-dataproc.googleapis.com",
  )
  return self

- def dataprocConfig(self, dataproc_config: Session):
+ def dataprocSessionConfig(self, dataproc_config: Session):
  with self._lock:
  self._dataproc_config = dataproc_config
  for k, v in dataproc_config.runtime_config.properties.items():
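The two renames in this hunk, region() to location() and dataprocConfig() to dataprocSessionConfig(), change the public builder surface. A minimal sketch of the 0.7.0 spelling; the top-level import path and the region value are assumptions:

    from google.cloud.dataproc_v1 import Session
    from google.cloud.dataproc_spark_connect import DataprocSparkSession  # import path assumed

    spark = (
        DataprocSparkSession.builder
        .location("us-central1")           # was .region(...) in 0.2.1
        .dataprocSessionConfig(Session())  # was .dataprocConfig(...) in 0.2.1
        .getOrCreate()
    )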
@@ -135,14 +143,14 @@ class DataprocSparkSession(SparkSession):
  else:
  return self

- def create(self) -> "SparkSession":
+ def create(self) -> "DataprocSparkSession":
  raise NotImplemented(
  "DataprocSparkSession allows session creation only through getOrCreate"
  )

  def __create_spark_connect_session_from_s8s(
- self, session_response
- ) -> "SparkSession":
+ self, session_response, session_name
+ ) -> "DataprocSparkSession":
  DataprocSparkSession._active_s8s_session_uuid = (
  session_response.uuid
  )
@@ -152,10 +160,14 @@ class DataprocSparkSession(SparkSession):
  spark_connect_url = session_response.runtime_info.endpoints.get(
  "Spark Connect Server"
  )
- spark_connect_url = spark_connect_url.replace("https", "sc")
- url = f"{spark_connect_url.replace('.com/', '.com:443/')};session_id={session_response.uuid};use_ssl=true"
+ url = f"{spark_connect_url}/;session_id={session_response.uuid};use_ssl=true"
  logger.debug(f"Spark Connect URL: {url}")
- self._channel_builder = DataprocChannelBuilder(url)
+ self._channel_builder = DataprocChannelBuilder(
+ url,
+ is_active_callback=lambda: is_s8s_session_active(
+ session_name, self._client_options
+ ),
+ )

  assert self._channel_builder is not None
  session = DataprocSparkSession(connection=self._channel_builder)
@@ -164,77 +176,104 @@ class DataprocSparkSession(SparkSession):
  self.__apply_options(session)
  return session

- def __create(self) -> "SparkSession":
+ def __create(self) -> "DataprocSparkSession":
  with self._lock:

  if self._options.get("spark.remote", False):
  raise NotImplemented(
- "DataprocSparkSession does not support connecting to an existing remote server"
+ "DataprocSparkSession does not support connecting to an existing Spark Connect remote server"
  )

  from google.cloud.dataproc_v1 import SessionControllerClient

  dataproc_config: Session = self._get_dataproc_config()
- session_template: SessionTemplate = self._get_session_template()

- self._get_and_validate_version(
- dataproc_config, session_template
- )
+ self._validate_version(dataproc_config)

- spark_connect_session = self._get_spark_connect_session(
- dataproc_config, session_template
- )
-
- spark = self._get_spark(dataproc_config, session_template)
-
- if not spark_connect_session:
- dataproc_config.spark_connect_session = {}
- if not spark:
- dataproc_config.spark = {}
- os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
- session_request = CreateSessionRequest()
  session_id = self.generate_dataproc_session_id()
-
- session_request.session_id = session_id
  dataproc_config.name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
  logger.debug(
- f"Configurations used to create serverless session:\n {dataproc_config}"
+ f"Dataproc Session configuration:\n{dataproc_config}"
  )
+
+ session_request = CreateSessionRequest()
+ session_request.session_id = session_id
  session_request.session = dataproc_config
  session_request.parent = (
  f"projects/{self._project_id}/locations/{self._region}"
  )

- logger.debug("Creating serverless session")
+ logger.debug("Creating Dataproc Session")
  DataprocSparkSession._active_s8s_session_id = session_id
  s8s_creation_start_time = time.time()
- try:
- session_polling = retry.Retry(
- predicate=POLLING_PREDICATE,
- initial=5.0, # seconds
- maximum=5.0, # seconds
- multiplier=1.0,
- timeout=600, # seconds
- )
- logger.info(
- "Creating Spark session. It may take few minutes."
+
+ stop_create_session_pbar = False
+
+ def create_session_pbar():
+ iterations = 150
+ pbar = tqdm.trange(
+ iterations,
+ bar_format="{bar}",
+ ncols=80,
  )
+ for i in pbar:
+ if stop_create_session_pbar:
+ break
+ # Last iteration
+ if i >= iterations - 1:
+ # Sleep until session created
+ while not stop_create_session_pbar:
+ time.sleep(1)
+ else:
+ time.sleep(1)
+
+ pbar.close()
+ # Print new line after the progress bar
+ print()
+
+ create_session_pbar_thread = threading.Thread(
+ target=create_session_pbar
+ )
+
+ try:
+ if (
+ os.getenv(
+ "DATAPROC_SPARK_CONNECT_SESSION_TERMINATE_AT_EXIT",
+ "false",
+ )
+ == "true"
+ ):
+ atexit.register(
+ lambda: terminate_s8s_session(
+ self._project_id,
+ self._region,
+ session_id,
+ self._client_options,
+ )
+ )
  operation = SessionControllerClient(
  client_options=self._client_options
  ).create_session(session_request)
  print(
- f"Interactive Session Detail View: https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}"
+ f"Creating Dataproc Session: https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}?project={self._project_id}"
  )
+ create_session_pbar_thread.start()
  session_response: Session = operation.result(
- polling=session_polling
+ polling=retry.Retry(
+ predicate=POLLING_PREDICATE,
+ initial=5.0, # seconds
+ maximum=5.0, # seconds
+ multiplier=1.0,
+ timeout=600, # seconds
+ )
  )
- if (
- "DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH"
- in os.environ
- ):
- file_path = os.environ[
- "DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH"
- ]
+ stop_create_session_pbar = True
+ create_session_pbar_thread.join()
+ print("Dataproc Session was successfully created")
+ file_path = (
+ DataprocSparkSession._get_active_session_file_path()
+ )
+ if file_path is not None:
  try:
  session_data = {
  "session_name": session_response.name,
@@ -247,75 +286,67 @@ class DataprocSparkSession(SparkSession):
  json.dump(session_data, json_file, indent=4)
  except Exception as e:
  logger.error(
- f"Exception while writing active session to file {file_path} , {e}"
+ f"Exception while writing active session to file {file_path}, {e}"
  )
- except InvalidArgument as e:
+ except (InvalidArgument, PermissionDenied) as e:
+ stop_create_session_pbar = True
+ if create_session_pbar_thread.is_alive():
+ create_session_pbar_thread.join()
  DataprocSparkSession._active_s8s_session_id = None
- raise RuntimeError(
- f"Error while creating serverless session: {e}"
- ) from None
+ raise DataprocSparkConnectException(
+ f"Error while creating Dataproc Session: {e.message}"
+ )
  except Exception as e:
+ stop_create_session_pbar = True
+ if create_session_pbar_thread.is_alive():
+ create_session_pbar_thread.join()
  DataprocSparkSession._active_s8s_session_id = None
  raise RuntimeError(
- f"Error while creating serverless session https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id} : {e}"
- ) from None
+ f"Error while creating Dataproc Session"
+ ) from e

  logger.debug(
- f"Serverless session created: {session_id}, creation time taken: {int(time.time() - s8s_creation_start_time)} seconds"
+ f"Dataproc Session created: {session_id} in {int(time.time() - s8s_creation_start_time)} seconds"
  )
  return self.__create_spark_connect_session_from_s8s(
- session_response
+ session_response, dataproc_config.name
  )

- def _is_s8s_session_active(
- self, s8s_session_id: str
- ) -> Optional[sessions.Session]:
- session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
- get_session_request = GetSessionRequest()
- get_session_request.name = session_name
- state = None
- try:
- get_session_response = SessionControllerClient(
- client_options=self._client_options
- ).get_session(get_session_request)
- state = get_session_response.state
- except Exception as e:
- logger.debug(f"{s8s_session_id} deleted: {e}")
- return None
-
- if state is not None and (
- state == Session.State.ACTIVE or state == Session.State.CREATING
- ):
- return get_session_response
- return None
-
- def _get_exiting_active_session(self) -> Optional["SparkSession"]:
+ def _get_exiting_active_session(
+ self,
+ ) -> Optional["DataprocSparkSession"]:
  s8s_session_id = DataprocSparkSession._active_s8s_session_id
- session_response = self._is_s8s_session_active(s8s_session_id)
+ session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
+ session_response = None
+ session = None
+ if s8s_session_id is not None:
+ session_response = get_active_s8s_session_response(
+ session_name, self._client_options
+ )
+ session = DataprocSparkSession.getActiveSession()

- session = DataprocSparkSession.getActiveSession()
  if session is None:
  session = DataprocSparkSession._default_session

  if session_response is not None:
  print(
- f"Using existing session: https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}, configuration changes may not be applied."
+ f"Using existing Dataproc Session (configuration changes may not be applied): https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}?project={self._project_id}"
  )
  if session is None:
  session = self.__create_spark_connect_session_from_s8s(
- session_response
+ session_response, session_name
  )
  return session
  else:
- logger.info(
- f"Session: {s8s_session_id} not active, stopping previous spark session and creating new"
- )
  if session is not None:
+ print(
+ f"{s8s_session_id} Dataproc Session is not active, stopping and creating a new one"
+ )
  session.stop()

  return None

- def getOrCreate(self) -> "SparkSession":
+ def getOrCreate(self) -> "DataprocSparkSession":
  with DataprocSparkSession._lock:
  session = self._get_exiting_active_session()
  if session is None:
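The creation path two hunks above drives a tqdm bar from a helper thread and flips a shared flag once operation.result() returns; the error handlers in the hunk directly above stop and join that same thread before re-raising. A self-contained sketch of the pattern, with a sleep standing in for the long-running call:

    import threading
    import time

    import tqdm

    done = False

    def progress_bar(iterations: int = 20) -> None:
        # Tick roughly once per second until the main thread flips `done`.
        pbar = tqdm.trange(iterations, bar_format="{bar}", ncols=80)
        for i in pbar:
            if done:
                break
            if i >= iterations - 1:
                # Hold on the last tick until the work finishes.
                while not done:
                    time.sleep(1)
            else:
                time.sleep(1)
        pbar.close()
        print()

    worker = threading.Thread(target=progress_bar)
    worker.start()
    time.sleep(5)  # stand-in for operation.result(...)
    done = True
    worker.join()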
@@ -330,21 +361,52 @@ class DataprocSparkSession(SparkSession):
  dataproc_config = self._dataproc_config
  for k, v in self._options.items():
  dataproc_config.runtime_config.properties[k] = v
- elif "DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG" in os.environ:
- filepath = os.environ[
- "DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG"
+ dataproc_config.spark_connect_session = (
+ sessions.SparkConnectConfig()
+ )
+ if not dataproc_config.runtime_config.version:
+ dataproc_config.runtime_config.version = (
+ DataprocSparkSession._DEFAULT_RUNTIME_VERSION
+ )
+ if (
+ not dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type
+ and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
+ ):
+ dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
+ os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
  ]
- try:
- with open(filepath, "r") as f:
- dataproc_config = Session.wrap(
- text_format.Parse(
- f.read(), Session.pb(dataproc_config)
- )
- )
- except FileNotFoundError:
- raise FileNotFoundError(f"File '{filepath}' not found")
- except ParseError as e:
- raise ParseError(f"Error parsing file '{filepath}': {e}")
+ if (
+ not dataproc_config.environment_config.execution_config.service_account
+ and "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT" in os.environ
+ ):
+ dataproc_config.environment_config.execution_config.service_account = os.getenv(
+ "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"
+ )
+ if (
+ not dataproc_config.environment_config.execution_config.subnetwork_uri
+ and "DATAPROC_SPARK_CONNECT_SUBNET" in os.environ
+ ):
+ dataproc_config.environment_config.execution_config.subnetwork_uri = os.getenv(
+ "DATAPROC_SPARK_CONNECT_SUBNET"
+ )
+ if (
+ not dataproc_config.environment_config.execution_config.ttl
+ and "DATAPROC_SPARK_CONNECT_TTL_SECONDS" in os.environ
+ ):
+ dataproc_config.environment_config.execution_config.ttl = {
+ "seconds": int(
+ os.getenv("DATAPROC_SPARK_CONNECT_TTL_SECONDS")
+ )
+ }
+ if (
+ not dataproc_config.environment_config.execution_config.idle_ttl
+ and "DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS" in os.environ
+ ):
+ dataproc_config.environment_config.execution_config.idle_ttl = {
+ "seconds": int(
+ os.getenv("DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS")
+ )
+ }
  if "COLAB_NOTEBOOK_RUNTIME_ID" in os.environ:
  dataproc_config.labels["colab-notebook-runtime-id"] = (
  os.environ["COLAB_NOTEBOOK_RUNTIME_ID"]
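Each block above fills an execution_config field only when the Session config leaves it unset and the matching environment variable is present. A sketch of opting into those defaults before calling getOrCreate(); every value below is a placeholder, and the auth-type string must name an AuthenticationConfig.AuthenticationType member:

    import os

    # Placeholders; each is applied only if the corresponding field is unset.
    os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = "END_USER_CREDENTIALS"  # enum member name, assumed
    os.environ["DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"] = "sa@my-project.iam.gserviceaccount.com"
    os.environ["DATAPROC_SPARK_CONNECT_SUBNET"] = "projects/my-project/regions/us-central1/subnetworks/default"
    os.environ["DATAPROC_SPARK_CONNECT_TTL_SECONDS"] = "3600"      # maximum session lifetime
    os.environ["DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS"] = "900"  # terminate after 15 idle minutes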
@@ -355,95 +417,38 @@ class DataprocSparkSession(SparkSession):
  ]
  return dataproc_config

- def _get_session_template(self):
- from google.cloud.dataproc_v1 import (
- GetSessionTemplateRequest,
- SessionTemplateControllerClient,
- )
-
- session_template = None
- if self._dataproc_config and self._dataproc_config.session_template:
- session_template = self._dataproc_config.session_template
- get_session_template_request = GetSessionTemplateRequest()
- get_session_template_request.name = session_template
- client = SessionTemplateControllerClient(
- client_options=self._client_options
- )
- try:
- session_template = client.get_session_template(
- get_session_template_request
- )
- except Exception as e:
- logger.error(
- f"Failed to get session template {session_template}: {e}"
- )
- raise
- return session_template
+ def _validate_version(self, dataproc_config):
+ trim_version = lambda v: ".".join(v.split(".")[:2])

- def _get_and_validate_version(self, dataproc_config, session_template):
- trimmed_version = lambda v: ".".join(v.split(".")[:2])
- version = None
+ version = dataproc_config.runtime_config.version
  if (
- dataproc_config
- and dataproc_config.runtime_config
- and dataproc_config.runtime_config.version
- ):
- version = dataproc_config.runtime_config.version
- elif (
- session_template
- and session_template.runtime_config
- and session_template.runtime_config.version
- ):
- version = session_template.runtime_config.version
-
- if not version:
- version = "3.0"
- dataproc_config.runtime_config.version = version
- elif (
- trimmed_version(version)
- not in self._dataproc_runtime_spark_version
+ trim_version(version)
+ not in self._dataproc_runtime_to_spark_version
  ):
  raise ValueError(
- f"runtime_config.version {version} is not supported. "
- f"Supported versions: {self._dataproc_runtime_spark_version.keys()}"
+ f"Specified {version} Dataproc Spark runtime version is not supported. "
+ f"Supported runtime versions: {self._dataproc_runtime_to_spark_version.keys()}"
  )

- server_version = self._dataproc_runtime_spark_version[
- trimmed_version(version)
+ server_version = self._dataproc_runtime_to_spark_version[
+ trim_version(version)
  ]
+
  import importlib.metadata

  dataproc_connect_version = importlib.metadata.version(
  "dataproc-spark-connect"
  )
  client_version = importlib.metadata.version("pyspark")
- version_message = f"Dataproc Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Dataproc Session Runtime: {version} (Spark: {server_version})"
- logger.info(version_message)
- if trimmed_version(client_version) != trimmed_version(
- server_version
- ):
- logger.warning(
- f"client and server on different versions: {version_message}"
+ if trim_version(client_version) != trim_version(server_version):
+ print(
+ f"Spark Connect client and server use different versions:\n"
+ f"- Dataproc Spark Connect client {dataproc_connect_version} (PySpark {client_version})\n"
+ f"- Dataproc Spark runtime {version} (Spark {server_version})"
  )
- return version
-
- def _get_spark_connect_session(self, dataproc_config, session_template):
- spark_connect_session = None
- if dataproc_config and dataproc_config.spark_connect_session:
- spark_connect_session = dataproc_config.spark_connect_session
- elif session_template and session_template.spark_connect_session:
- spark_connect_session = session_template.spark_connect_session
- return spark_connect_session
-
- def _get_spark(self, dataproc_config, session_template):
- spark = None
- if dataproc_config and dataproc_config.spark:
- spark = dataproc_config.spark
- elif session_template and session_template.spark:
- spark = session_template.spark
- return spark
-
- def generate_dataproc_session_id(self):
+
+ @staticmethod
+ def generate_dataproc_session_id():
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  suffix_length = 6
  random_suffix = "".join(
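_validate_version now reads the version straight from dataproc_config.runtime_config.version (a missing version has already been defaulted to _DEFAULT_RUNTIME_VERSION in _get_dataproc_config) and compares only the major.minor prefix. A toy illustration of the trimming and the runtime-to-Spark lookup from this hunk:

    runtime_to_spark = {"1.2": "3.5", "2.2": "3.5", "2.3": "3.5", "3.0": "4.0"}
    trim_version = lambda v: ".".join(v.split(".")[:2])

    assert trim_version("2.2.47") == "2.2"
    assert runtime_to_spark[trim_version("3.0.1")] == "4.0"
    assert trim_version("2.5") not in runtime_to_spark  # would raise ValueError in _validate_version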
@@ -456,75 +461,102 @@ class DataprocSparkSession(SparkSession):
  def _repr_html_(self) -> str:
  if not self._active_s8s_session_id:
  return """
- <div>No Active Dataproc Spark Session</div>
+ <div>No Active Dataproc Session</div>
  """

  s8s_session = f"https://console.cloud.google.com/dataproc/interactive/{self._region}/{self._active_s8s_session_id}"
  ui = f"{s8s_session}/sparkApplications/applications"
- version = ""
  return f"""
  <div>
  <p><b>Spark Connect</b></p>

- <p><a href="{s8s_session}">Dataproc Session</a></p>
- <p><a href="{ui}">Spark UI</a></p>
+ <p><a href="{s8s_session}?project={self._project_id}">Dataproc Session</a></p>
+ <p><a href="{ui}?project={self._project_id}">Spark UI</a></p>
  </div>
  """

- def _remove_stoped_session_from_file(self):
- if "DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH" in os.environ:
- file_path = os.environ[
- "DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH"
- ]
+ @staticmethod
+ def _remove_stopped_session_from_file():
+ file_path = DataprocSparkSession._get_active_session_file_path()
+ if file_path is not None:
  try:
  with open(file_path, "w"):
  pass
  except Exception as e:
  logger.error(
- f"Exception while removing active session in file {file_path} , {e}"
+ f"Exception while removing active session in file {file_path}, {e}"
  )

+ def addArtifacts(
+ self,
+ *artifact: str,
+ pyfile: bool = False,
+ archive: bool = False,
+ file: bool = False,
+ pypi: bool = False,
+ ) -> None:
+ """
+ Add artifact(s) to the client session. Currently only local files & pypi installations are supported.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ *artifact : tuple of str
+ Artifact's URIs to add.
+ pyfile : bool
+ Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files.
+ The pyfiles are directly inserted into the path when executing Python functions
+ in executors.
+ archive : bool
+ Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
+ The archives are unpacked on the executor side automatically.
+ file : bool
+ Add a file to be downloaded with this Spark job on every node.
+ The ``path`` passed can only be a local file for now.
+ pypi : bool
+ This option is only available with DataprocSparkSession. e.g. `spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)`
+ Installs PyPi package (with its dependencies) in the active Spark session on the driver and executors.
+
+ Notes
+ -----
+ This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
+ an exception.
+ Regarding pypi: Popular packages are already pre-installed in s8s runtime.
+ https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.2#python_libraries
+ If there are conflicts/package doesn't exist, it throws an exception.
+ """
+ if sum([pypi, file, pyfile, archive]) > 1:
+ raise ValueError(
+ "'pyfile', 'archive', 'file' and/or 'pypi' cannot be True together."
+ )
+ if pypi:
+ artifacts = PyPiArtifacts(set(artifact))
+ logger.debug("Making addArtifact call to install packages")
+ self.addArtifact(
+ artifacts.write_packages_config(self._active_s8s_session_uuid),
+ file=True,
+ )
+ else:
+ super().addArtifacts(
+ *artifact, pyfile=pyfile, archive=archive, file=file
+ )
+
+ @staticmethod
+ def _get_active_session_file_path():
+ return os.getenv("DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH")
+
  def stop(self) -> None:
  with DataprocSparkSession._lock:
  if DataprocSparkSession._active_s8s_session_id is not None:
- from google.cloud.dataproc_v1 import SessionControllerClient
-
- logger.debug(
- f"Terminating serverless session: {DataprocSparkSession._active_s8s_session_id}"
+ terminate_s8s_session(
+ DataprocSparkSession._project_id,
+ DataprocSparkSession._region,
+ DataprocSparkSession._active_s8s_session_id,
+ self._client_options,
  )
- terminate_session_request = TerminateSessionRequest()
- session_name = f"projects/{DataprocSparkSession._project_id}/locations/{DataprocSparkSession._region}/sessions/{DataprocSparkSession._active_s8s_session_id}"
- terminate_session_request.name = session_name
- state = None
- try:
- SessionControllerClient(
- client_options=self._client_options
- ).terminate_session(terminate_session_request)
- get_session_request = GetSessionRequest()
- get_session_request.name = session_name
- state = Session.State.ACTIVE
- while (
- state != Session.State.TERMINATING
- and state != Session.State.TERMINATED
- and state != Session.State.FAILED
- ):
- session = SessionControllerClient(
- client_options=self._client_options
- ).get_session(get_session_request)
- state = session.state
- sleep(1)
- except NotFound:
- logger.debug(
- f"Session {DataprocSparkSession._active_s8s_session_id} already deleted"
- )
- except FailedPrecondition:
- logger.debug(
- f"Session {DataprocSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
- )
- if state is not None and state == Session.State.FAILED:
- raise RuntimeError("Serverless session termination failed")

- self._remove_stoped_session_from_file()
+ self._remove_stopped_session_from_file()
  DataprocSparkSession._active_s8s_session_uuid = None
  DataprocSparkSession._active_s8s_session_id = None
  DataprocSparkSession._project_id = None
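The new pypi flag on addArtifacts wraps the requested packages in a PyPiArtifacts config and ships it through the regular addArtifact(file=True) mechanism, so installation happens on both the driver and the executors. The docstring's own example, plus one of the unchanged upstream modes (the local path is illustrative):

    # `spark` is an active DataprocSparkSession
    spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)  # install PyPI packages with dependencies

    # The pyfile/archive/file modes keep their PySpark behaviour and are mutually exclusive with pypi:
    spark.addArtifacts("/tmp/helpers.py", pyfile=True)      # illustrative local path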
@@ -538,3 +570,68 @@ class DataprocSparkSession(SparkSession):
  DataprocSparkSession._active_session, "session", None
  ):
  DataprocSparkSession._active_session.session = None
+
+
+ def terminate_s8s_session(
+ project_id, region, active_s8s_session_id, client_options=None
+ ):
+ from google.cloud.dataproc_v1 import SessionControllerClient
+
+ logger.debug(f"Terminating Dataproc Session: {active_s8s_session_id}")
+ terminate_session_request = TerminateSessionRequest()
+ session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}"
+ terminate_session_request.name = session_name
+ state = None
+ try:
+ session_client = SessionControllerClient(client_options=client_options)
+ session_client.terminate_session(terminate_session_request)
+ get_session_request = GetSessionRequest()
+ get_session_request.name = session_name
+ state = Session.State.ACTIVE
+ while (
+ state != Session.State.TERMINATING
+ and state != Session.State.TERMINATED
+ and state != Session.State.FAILED
+ ):
+ session = session_client.get_session(get_session_request)
+ state = session.state
+ time.sleep(1)
+ except NotFound:
+ logger.debug(
+ f"{active_s8s_session_id} Dataproc Session already deleted"
+ )
+ # Client will get 'Aborted' error if session creation is still in progress and
+ # 'FailedPrecondition' if another termination is still in progress.
+ # Both are retryable, but we catch it and let TTL take care of cleanups.
+ except (FailedPrecondition, Aborted):
+ logger.debug(
+ f"{active_s8s_session_id} Dataproc Session already terminated manually or automatically due to TTL"
+ )
+ if state is not None and state == Session.State.FAILED:
+ raise RuntimeError("Dataproc Session termination failed")
+
+
+ def get_active_s8s_session_response(
+ session_name, client_options
+ ) -> Optional[sessions.Session]:
+ get_session_request = GetSessionRequest()
+ get_session_request.name = session_name
+ try:
+ get_session_response = SessionControllerClient(
+ client_options=client_options
+ ).get_session(get_session_request)
+ state = get_session_response.state
+ except Exception as e:
+ print(f"{session_name} Dataproc Session deleted: {e}")
+ return None
+ if state is not None and (
+ state == Session.State.ACTIVE or state == Session.State.CREATING
+ ):
+ return get_session_response
+ return None
+
+
+ def is_s8s_session_active(session_name, client_options) -> bool:
+ if get_active_s8s_session_response(session_name, client_options) is None:
+ return False
+ return True
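terminate_s8s_session is now a module-level helper: stop() calls it, and so does the optional atexit hook registered when DATAPROC_SPARK_CONNECT_SESSION_TERMINATE_AT_EXIT is "true". A sketch of both ways to trigger it; the module import path, project, region, and session ID are assumptions:

    import os

    # Option 1: register the atexit hook (must be set before getOrCreate()).
    os.environ["DATAPROC_SPARK_CONNECT_SESSION_TERMINATE_AT_EXIT"] = "true"

    # Option 2: terminate a known session explicitly.
    from google.cloud.dataproc_spark_connect.session import terminate_s8s_session  # module path assumed

    terminate_s8s_session(
        project_id="my-project",
        region="us-central1",
        active_s8s_session_id="my-session-id",  # placeholder session ID
    )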