apache-airflow-providers-teradata 3.2.3rc1__py3-none-any.whl → 3.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,640 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements.  See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership.  The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License.  You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied.  See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ from airflow.models import BaseOperator
+ from airflow.providers.ssh.hooks.ssh import SSHHook
+ from airflow.providers.teradata.hooks.teradata import TeradataHook
+ from airflow.providers.teradata.hooks.tpt import TptHook
+ from airflow.providers.teradata.utils.tpt_util import (
+     get_remote_temp_directory,
+     is_valid_file,
+     is_valid_remote_job_var_file,
+     prepare_tdload_job_var_file,
+     prepare_tpt_ddl_script,
+     read_file,
+ )
+
+ if TYPE_CHECKING:
+     from paramiko import SSHClient
+
+     from airflow.sdk import Context
+
+
+ class DdlOperator(BaseOperator):
+     """
+     Operator to execute one or more DDL (Data Definition Language) statements on a Teradata Database.
+
+     This operator is designed to facilitate DDL operations such as creating, altering, or dropping
+     tables, indexes, views, or other database objects in a scalable and efficient manner.
+
+     It leverages the TPT (Teradata Parallel Transporter) utility to perform the operations and
+     supports templating for SQL statements, allowing dynamic generation of SQL at runtime.
+
+     Key features:
+
+     - Executes one or more DDL statements sequentially on Teradata using TPT
+     - Supports error handling with a customizable list of error codes to ignore
+     - Supports XCom push to share execution results with downstream tasks
+     - Integrates with Airflow's templating engine for dynamic SQL generation
+     - Can execute statements via an SSH connection if needed
+
+     :param ddl: A list of DDL statements to be executed. Each item should be a valid SQL
+         DDL command supported by Teradata.
+     :param error_list: Optional integer or list of error codes to ignore during execution.
+         If provided, the operator will not fail when these specific error codes occur.
+         Example: error_list=3803 or error_list=[3803, 3807]
+     :param teradata_conn_id: The connection ID for the Teradata database.
+         Defaults to TeradataHook.default_conn_name.
+     :param ssh_conn_id: Optional SSH connection ID if the commands need to be executed through SSH.
+     :param remote_working_dir: Directory on the remote server where temporary files will be stored.
+     :param ddl_job_name: Optional name for the DDL job.
+     :raises ValueError: If the ddl parameter or error_list is invalid.
+     :raises RuntimeError: If the underlying TPT execution (tbuild) fails with a non-zero exit status.
+     :raises ConnectionError: If the remote SSH connection cannot be established.
+     :raises TimeoutError: If the SSH connection attempt times out.
+     :raises FileNotFoundError: If the required TPT utility (tbuild) is missing locally or on the remote host.
+
+     Example usage::
+
+         # Example of creating tables using DdlOperator
+         create_tables = DdlOperator(
+             task_id="create_tables_task",
+             ddl=[
+                 "CREATE TABLE my_database.my_table1 (id INT, name VARCHAR(100))",
+                 "CREATE TABLE my_database.my_table2 (id INT, value FLOAT)",
+             ],
+             teradata_conn_id="my_teradata_conn",
+             error_list=[3803],  # Ignore "Table already exists" errors
+             ddl_job_name="create_tables_job",
+         )
+
+         # Example of dropping tables using DdlOperator
+         drop_tables = DdlOperator(
+             task_id="drop_tables_task",
+             ddl=["DROP TABLE my_database.my_table1", "DROP TABLE my_database.my_table2"],
+             teradata_conn_id="my_teradata_conn",
+             error_list=3807,  # Ignore "Object does not exist" errors
+             ddl_job_name="drop_tables_job",
+         )
+
+         # Example using a templated SQL file (rendered via template_ext);
+         # note that ddl must still be a list of statements after rendering
+         alter_table = DdlOperator(
+             task_id="alter_table_task",
+             ddl=["{{ var.value.get('ddl_directory') }}/alter_table.sql"],
+             teradata_conn_id="my_teradata_conn",
+             ssh_conn_id="my_ssh_conn",
+             ddl_job_name="alter_table_job",
+         )
+     """
+
+     template_fields = ("ddl", "ddl_job_name")
+     template_ext = (".sql",)
+     ui_color = "#a8e4b1"
+
+     def __init__(
+         self,
+         *,
+         ddl: list[str],
+         error_list: int | list[int] | None = None,
+         teradata_conn_id: str = TeradataHook.default_conn_name,
+         ssh_conn_id: str | None = None,
+         remote_working_dir: str | None = None,
+         ddl_job_name: str | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.ddl = ddl
+         self.error_list = error_list
+         self.teradata_conn_id = teradata_conn_id
+         self.ssh_conn_id = ssh_conn_id
+         self.remote_working_dir = remote_working_dir
+         self.ddl_job_name = ddl_job_name
+         self._hook: TptHook | None = None
+         self._ssh_hook: SSHHook | None = None
+
+     def execute(self, context: Context) -> int | None:
+         """Execute the DDL operations using the TptHook."""
+         # Validate the ddl parameter
+         if (
+             not self.ddl
+             or not isinstance(self.ddl, list)
+             or not all(isinstance(stmt, str) and stmt.strip() for stmt in self.ddl)
+         ):
+             raise ValueError(
+                 "ddl parameter must be a non-empty list of non-empty strings representing DDL statements."
+             )
+
+         # Normalize error_list to a list of ints
+         normalized_error_list = self._normalize_error_list(self.error_list)
+
+         self.log.info("Initializing Teradata connection using teradata_conn_id: %s", self.teradata_conn_id)
+         self._hook = TptHook(teradata_conn_id=self.teradata_conn_id, ssh_conn_id=self.ssh_conn_id)
+         self._ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if self.ssh_conn_id else None
+
+         try:
+             # Prepare TPT script for DDL execution
+             tpt_ddl_script = prepare_tpt_ddl_script(
+                 sql=self.ddl,
+                 error_list=normalized_error_list,
+                 source_conn=self._hook.get_conn(),
+                 job_name=self.ddl_job_name,
+             )
+
+             # Set remote working directory if SSH is used
+             if self._ssh_hook and not self.remote_working_dir:
+                 self.remote_working_dir = get_remote_temp_directory(
+                     self._ssh_hook.get_conn(), logging.getLogger(__name__)
+                 )
+             # Ensure remote_working_dir has a value even for local execution
+             if not self.remote_working_dir:
+                 self.remote_working_dir = "/tmp"
+
+             return self._hook.execute_ddl(
+                 tpt_ddl_script,
+                 self.remote_working_dir,
+             )
+         except Exception as e:
+             self.log.error("Failed to execute DDL operations: %s", str(e))
+             raise
+
+     def _normalize_error_list(self, error_list: int | list[int] | None) -> list[int]:
+         """
+         Normalize the error_list parameter to a list of integers.
+
+         Args:
+             error_list: An integer, a list of integers, or None
+
+         Returns:
+             A list of integers representing error codes to ignore
+
+         Raises:
+             ValueError: If error_list is not of the expected type
+         """
+         if error_list is None:
+             return []
+         if isinstance(error_list, int):
+             return [error_list]
+         if isinstance(error_list, list) and all(isinstance(err, int) for err in error_list):
+             return error_list
+         raise ValueError(
+             f"error_list must be an int or a list of ints, got {type(error_list).__name__}. "
+             "Example: error_list=3803 or error_list=[3803, 3807]"
+         )
+
+     def on_kill(self):
+         """Handle termination signals and ensure the hook is properly cleaned up."""
+         self.log.info("Cleaning up TPT DDL connections on task kill")
+         if self._hook:
+             try:
+                 self._hook.on_kill()
+                 self.log.info("TPT DDL hook cleaned up successfully")
+             except Exception as e:
+                 self.log.error("Error cleaning up TPT DDL hook: %s", str(e))
+         else:
+             self.log.warning("No TptHook initialized to clean up on task kill")
+
+
+ class TdLoadOperator(BaseOperator):
+     """
+     Operator to handle data transfers using the Teradata Parallel Transporter (TPT) tdload utility.
+
+     This operator supports three main scenarios:
+
+     1. Load data from a file to a Teradata table
+     2. Export data from a Teradata table to a file
+     3. Transfer data between two Teradata tables (potentially across different databases)
+
+     For all scenarios:
+
+     :param teradata_conn_id: Connection ID for the Teradata database (source for table operations)
+
+     For file to table loading:
+
+     :param source_file_name: Path to the source file (required for file to table)
+     :param select_stmt: SQL SELECT statement to filter data (optional)
+     :param insert_stmt: SQL INSERT statement to use for loading data (optional)
+     :param target_table: Name of the target table (required for file to table)
+     :param target_teradata_conn_id: Connection ID for the target Teradata database (defaults to teradata_conn_id)
+
+     For table to file export:
+
+     :param source_table: Name of the source table (required for table to file)
+     :param target_file_name: Path to the target file (required for table to file)
+
+     For table to table transfer:
+
+     :param source_table: Name of the source table (required for table to table)
+     :param select_stmt: SQL SELECT statement to filter data (optional)
+     :param insert_stmt: SQL INSERT statement to use for loading data (optional)
+     :param target_table: Name of the target table (required for table to table)
+     :param target_teradata_conn_id: Connection ID for the target Teradata database (required for table to table)
+
+     Optional configuration parameters:
+
+     :param source_format: Format of the source data (default: 'Delimited')
+     :param target_format: Format of the target data (default: 'Delimited')
+     :param source_text_delimiter: Source text delimiter (default: ',')
+     :param target_text_delimiter: Target text delimiter (default: ',')
+     :param tdload_options: Additional options for tdload (optional)
+     :param tdload_job_name: Name for the tdload job (optional)
+     :param tdload_job_var_file: Path to a tdload job variable file (optional)
+     :param remote_working_dir: Directory on the remote server where temporary files will be stored (optional)
+     :param ssh_conn_id: SSH connection ID used to run tdload on a remote host and transfer files
+         over SFTP (optional, used for file operations)
+
+     :raises ValueError: If parameter combinations are invalid or required files are missing.
+     :raises RuntimeError: If the underlying TPT execution (tdload) fails with a non-zero exit status.
+     :raises ConnectionError: If the remote SSH connection cannot be established.
+     :raises TimeoutError: If the SSH connection attempt times out.
+     :raises FileNotFoundError: If the required TPT utility (tdload) is missing locally or on the remote host.
+
+     Example usage::
+
+         # Example usage for file to table:
+         load_file = TdLoadOperator(
+             task_id="load_from_file",
+             source_file_name="/path/to/data.csv",
+             target_table="my_database.my_table",
+             target_teradata_conn_id="teradata_target_conn",
+             insert_stmt="INSERT INTO my_database.my_table (col1, col2) VALUES (?, ?)",
+         )
+
+         # Example usage for table to file:
+         export_data = TdLoadOperator(
+             task_id="export_to_file",
+             source_table="my_database.my_table",
+             target_file_name="/path/to/export.csv",
+             teradata_conn_id="teradata_source_conn",
+             ssh_conn_id="ssh_default",
+             tdload_job_name="export_job",
+         )
+
+         # Example usage for table to table:
+         transfer_data = TdLoadOperator(
+             task_id="transfer_between_tables",
+             source_table="source_db.source_table",
+             target_table="target_db.target_table",
+             teradata_conn_id="teradata_source_conn",
+             target_teradata_conn_id="teradata_target_conn",
+             tdload_job_var_file="/path/to/vars.txt",
+             insert_stmt="INSERT INTO target_db.target_table (col1, col2) VALUES (?, ?)",
+         )
+     """
+
+     template_fields = (
+         "source_table",
+         "target_table",
+         "select_stmt",
+         "insert_stmt",
+         "source_file_name",
+         "target_file_name",
+         "tdload_options",
+     )
+     ui_color = "#a8e4b1"
+
+     def __init__(
+         self,
+         *,
+         teradata_conn_id: str = TeradataHook.default_conn_name,
+         target_teradata_conn_id: str | None = None,
+         ssh_conn_id: str | None = None,
+         source_table: str | None = None,
+         select_stmt: str | None = None,
+         insert_stmt: str | None = None,
+         target_table: str | None = None,
+         source_file_name: str | None = None,
+         target_file_name: str | None = None,
+         source_format: str = "Delimited",
+         target_format: str = "Delimited",
+         source_text_delimiter: str = ",",
+         target_text_delimiter: str = ",",
+         tdload_options: str | None = None,
+         tdload_job_name: str | None = None,
+         tdload_job_var_file: str | None = None,
+         remote_working_dir: str | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.teradata_conn_id = teradata_conn_id
+         self.target_teradata_conn_id = target_teradata_conn_id
+         self.ssh_conn_id = ssh_conn_id
+         self.source_table = source_table
+         self.select_stmt = select_stmt
+         self.insert_stmt = insert_stmt
+         self.target_table = target_table
+         self.source_file_name = source_file_name
+         self.target_file_name = target_file_name
+         self.source_format = source_format
+         self.source_text_delimiter = source_text_delimiter
+         self.target_format = target_format
+         self.target_text_delimiter = target_text_delimiter
+         self.tdload_options = tdload_options
+         self.tdload_job_name = tdload_job_name
+         self.tdload_job_var_file = tdload_job_var_file
+         self.remote_working_dir = remote_working_dir
+         self._src_hook: TptHook | None = None
+         self._dest_hook: TptHook | None = None
+         self._ssh_hook: SSHHook | None = None
+
+     def execute(self, context: Context) -> int | None:
+         """Execute the TdLoad operation based on the configured parameters."""
+         # Validate parameter combinations and determine the operation mode
+         mode = self._validate_and_determine_mode()
+
+         # Initialize hooks
+         self._initialize_hooks(mode)
+
+         try:
+             # Prepare job variable file content if not provided
+             tdload_job_var_content = None
+             tdload_job_var_file = self.tdload_job_var_file
+
+             if not tdload_job_var_file:
+                 tdload_job_var_content = self._prepare_job_var_content(mode)
+                 self.log.info("Prepared tdload job variable content for mode '%s'", mode)
+
+             # Set remote working directory if SSH is used
+             if self._ssh_hook and not self.remote_working_dir:
+                 self.remote_working_dir = get_remote_temp_directory(
+                     self._ssh_hook.get_conn(), logging.getLogger(__name__)
+                 )
+             # Ensure remote_working_dir is always a str
+             if not self.remote_working_dir:
+                 self.remote_working_dir = "/tmp"
+
+             # Execute based on SSH availability and job var file source
+             return self._execute_based_on_configuration(tdload_job_var_file, tdload_job_var_content, context)
+
+         except Exception as e:
+             self.log.error("Failed to execute TdLoad operation in mode '%s': %s", mode, str(e))
+             raise
+
+     def _validate_and_determine_mode(self) -> str:
+         """
+         Validate parameters and determine the operation mode.
+
+         Returns:
+             A string indicating the operation mode: 'file_to_table', 'table_to_file',
+             'table_to_table', or 'job_var_file'
+
+         Raises:
+             ValueError: If parameter combinations are invalid
+         """
+         if self.source_table and self.select_stmt:
+             raise ValueError(
+                 "Both source_table and select_stmt cannot be provided simultaneously. "
+                 "Please provide only one."
+             )
+
+         if self.insert_stmt and not self.target_table:
+             raise ValueError(
+                 "insert_stmt is provided but target_table is not specified. "
+                 "Please provide a target_table for the insert operation."
+             )
+
+         # Determine the mode of operation based on provided parameters
+         if self.source_file_name and self.target_table:
+             mode = "file_to_table"
+             if self.target_teradata_conn_id is None:
+                 self.target_teradata_conn_id = self.teradata_conn_id
+             self.log.info(
+                 "Loading data from file '%s' to table '%s'", self.source_file_name, self.target_table
+             )
+         elif (self.source_table or self.select_stmt) and self.target_file_name:
+             mode = "table_to_file"
+             self.log.info(
+                 "Exporting data from %s to file '%s'",
+                 self.source_table or "custom select statement",
+                 self.target_file_name,
+             )
+         elif (self.source_table or self.select_stmt) and self.target_table:
+             mode = "table_to_table"
+             if self.target_teradata_conn_id is None:
+                 raise ValueError("For table to table transfer, target_teradata_conn_id must be provided.")
+             self.log.info(
+                 "Transferring data from %s to table '%s'",
+                 self.source_table or "custom select statement",
+                 self.target_table,
+             )
+         else:
+             if not self.tdload_job_var_file:
+                 raise ValueError(
+                     "Invalid parameter combination for the TdLoadOperator. Please provide one of these valid combinations:\n"
+                     "1. source_file_name and target_table: to load data from a file to a table\n"
+                     "2. source_table/select_stmt and target_file_name: to export data from a table to a file\n"
+                     "3. source_table/select_stmt and target_table: to transfer data between tables\n"
+                     "4. tdload_job_var_file: to use a pre-configured job variable file"
+                 )
+             mode = "job_var_file"
+             self.log.info("Using pre-configured job variable file: %s", self.tdload_job_var_file)
+
+         return mode
+
+     def _initialize_hooks(self, mode: str) -> None:
+         """
+         Initialize the required hooks based on the operation mode.
+
+         Args:
+             mode: The operation mode ('file_to_table', 'table_to_file', 'table_to_table', etc.)
+         """
+         self.log.info("Initializing source connection using teradata_conn_id: %s", self.teradata_conn_id)
+         self._src_hook = TptHook(teradata_conn_id=self.teradata_conn_id, ssh_conn_id=self.ssh_conn_id)
+
+         if mode in ("table_to_table", "file_to_table"):
+             self.log.info(
+                 "Initializing destination connection using target_teradata_conn_id: %s",
+                 self.target_teradata_conn_id,
+             )
+             self._dest_hook = TptHook(teradata_conn_id=self.target_teradata_conn_id)
+
+         self._ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if self.ssh_conn_id else None
+
+     def _prepare_job_var_content(self, mode: str) -> str:
+         """
+         Prepare the job variable file content.
+
+         Args:
+             mode: The operation mode
+
+         Returns:
+             The prepared job variable file content as a string
+         """
+         if not self._src_hook:
+             raise ValueError("Source hook not initialized")
+
+         return prepare_tdload_job_var_file(
+             mode=mode,
+             source_table=self.source_table,
+             select_stmt=self.select_stmt,
+             insert_stmt=self.insert_stmt,
+             target_table=self.target_table,
+             source_file_name=self.source_file_name,
+             target_file_name=self.target_file_name,
+             source_format=self.source_format,
+             target_format=self.target_format,
+             source_text_delimiter=self.source_text_delimiter,
+             target_text_delimiter=self.target_text_delimiter,
+             source_conn=self._src_hook.get_conn(),
+             target_conn=self._dest_hook.get_conn() if self._dest_hook else None,
+         )
+
+     def _execute_based_on_configuration(
+         self, tdload_job_var_file: str | None, tdload_job_var_content: str | None, context: Context
+     ) -> int | None:
+         """Execute the TdLoad operation based on SSH and job var file configuration."""
+         if self._ssh_hook:
+             if tdload_job_var_file:
+                 with self._ssh_hook.get_conn() as ssh_client:
+                     if is_valid_remote_job_var_file(
+                         ssh_client, tdload_job_var_file, logging.getLogger(__name__)
+                     ):
+                         return self._handle_remote_job_var_file(
+                             ssh_client=ssh_client,
+                             file_path=tdload_job_var_file,
+                             context=context,
+                         )
+                     raise ValueError(
+                         f"The provided remote job variables file path '{tdload_job_var_file}' is invalid or does not exist on remote machine."
+                     )
+             else:
+                 if not self._src_hook:
+                     raise ValueError("Source hook not initialized")
+                 # Ensure remote_working_dir is always a str
+                 remote_working_dir = self.remote_working_dir or "/tmp"
+                 return self._src_hook.execute_tdload(
+                     remote_working_dir,
+                     tdload_job_var_content,
+                     self.tdload_options,
+                     self.tdload_job_name,
+                 )
+         else:
+             if tdload_job_var_file:
+                 if is_valid_file(tdload_job_var_file):
+                     return self._handle_local_job_var_file(
+                         file_path=tdload_job_var_file,
+                         context=context,
+                     )
+                 raise ValueError(
+                     f"The provided job variables file path '{tdload_job_var_file}' is invalid or does not exist."
+                 )
+             if not self._src_hook:
+                 raise ValueError("Source hook not initialized")
+             # Ensure remote_working_dir is always a str
+             remote_working_dir = self.remote_working_dir or "/tmp"
+             return self._src_hook.execute_tdload(
+                 remote_working_dir,
+                 tdload_job_var_content,
+                 self.tdload_options,
+                 self.tdload_job_name,
+             )
+
+     def _handle_remote_job_var_file(
+         self,
+         ssh_client: SSHClient,
+         file_path: str | None,
+         context: Context,
+     ) -> int | None:
+         """Handle execution using a remote job variable file."""
+         if not file_path:
+             raise ValueError("Please provide a valid job variables file path on the remote machine.")
+
+         try:
+             sftp = ssh_client.open_sftp()
+             try:
+                 with sftp.open(file_path, "r") as remote_file:
+                     tdload_job_var_content = remote_file.read().decode("UTF-8")
+                     self.log.info("Successfully read remote job variable file: %s", file_path)
+             finally:
+                 sftp.close()
+
+             if self._src_hook:
+                 # Ensure remote_working_dir is always a str
+                 remote_working_dir = self.remote_working_dir or "/tmp"
+                 return self._src_hook._execute_tdload_via_ssh(
+                     remote_working_dir,
+                     tdload_job_var_content,
+                     self.tdload_options,
+                     self.tdload_job_name,
+                 )
+             raise ValueError("Source hook not initialized for remote execution.")
+         except Exception as e:
+             self.log.error("Failed to handle remote job variable file '%s': %s", file_path, str(e))
+             raise
+
+     def _handle_local_job_var_file(
+         self,
+         file_path: str | None,
+         context: Context,
+     ) -> int | None:
+         """
+         Handle execution using a local job variable file.
+
+         Args:
+             file_path: Path to the local job variable file
+             context: Airflow context
+
+         Returns:
+             Exit code from the TdLoad operation
+
+         Raises:
+             ValueError: If the file path is invalid or the hook is not initialized
+         """
+         if not file_path:
+             raise ValueError("Please provide a valid local job variables file path.")
+
+         if not is_valid_file(file_path):
+             raise ValueError(f"The job variables file path '{file_path}' is invalid or does not exist.")
+
+         try:
+             tdload_job_var_content = read_file(file_path, encoding="UTF-8")
+             self.log.info("Successfully read local job variable file: %s", file_path)
+
+             if self._src_hook:
+                 return self._src_hook._execute_tdload_locally(
+                     tdload_job_var_content,
+                     self.tdload_options,
+                     self.tdload_job_name,
+                 )
+             raise ValueError("Source hook not initialized for local execution.")
+
+         except Exception as e:
+             self.log.error("Failed to handle local job variable file '%s': %s", file_path, str(e))
+             raise
+
+     def on_kill(self):
+         """Handle termination signals and ensure all hooks are properly cleaned up."""
+         self.log.info("Cleaning up TPT tdload connections on task kill")
+
+         cleanup_errors = []
+
+         # Clean up the source hook if it was initialized
+         if self._src_hook:
+             try:
+                 self.log.info("Cleaning up source connection")
+                 self._src_hook.on_kill()
+             except Exception as e:
+                 cleanup_errors.append(f"Failed to cleanup source hook: {str(e)}")
+                 self.log.error("Error cleaning up source connection: %s", str(e))
+
+         # Clean up the destination hook if it was initialized
+         if self._dest_hook:
+             try:
+                 self.log.info("Cleaning up destination connection")
+                 self._dest_hook.on_kill()
+             except Exception as e:
+                 cleanup_errors.append(f"Failed to cleanup destination hook: {str(e)}")
+                 self.log.error("Error cleaning up destination connection: %s", str(e))
+
+         # Log any cleanup errors but don't raise them during shutdown
+         if cleanup_errors:
+             self.log.warning("Some cleanup operations failed: %s", "; ".join(cleanup_errors))
+         else:
+             self.log.info("All TPT connections cleaned up successfully")
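Taken together, the two operators introduced above cover schema setup (DdlOperator) and data movement (TdLoadOperator). The following is a minimal sketch of how they might be wired into a DAG, not an example shipped with the provider: the DAG id, task ids, connection ids, file paths, and the operator module path are hypothetical, and the Airflow 3 style `airflow.sdk` DAG API is assumed based on this file's own imports::

    # Hypothetical DAG wiring the two operators; connection ids and paths are placeholders.
    from datetime import datetime

    from airflow.sdk import DAG

    # Module path assumed; the diff does not show the new file's path.
    from airflow.providers.teradata.operators.tpt import DdlOperator, TdLoadOperator

    with DAG(dag_id="teradata_tpt_example", start_date=datetime(2025, 1, 1), schedule=None):
        create_staging = DdlOperator(
            task_id="create_staging",
            ddl=["CREATE TABLE my_database.staging (id INT, name VARCHAR(100))"],
            teradata_conn_id="my_teradata_conn",
            error_list=3803,  # ignore "table already exists"
        )

        load_staging = TdLoadOperator(
            task_id="load_staging",
            source_file_name="/data/staging.csv",
            target_table="my_database.staging",
            teradata_conn_id="my_teradata_conn",
        )

        # In file_to_table mode, target_teradata_conn_id defaults to teradata_conn_id.
        create_staging >> load_staging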
@@ -20,7 +20,7 @@ import asyncio
  from collections.abc import AsyncIterator
  from typing import Any

- from airflow.exceptions import AirflowException
+ from airflow.providers.common.compat.sdk import AirflowException
  from airflow.providers.common.sql.hooks.handlers import fetch_one_handler
  from airflow.providers.teradata.hooks.teradata import TeradataHook
  from airflow.providers.teradata.utils.constants import Constants
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any
  if TYPE_CHECKING:
      from paramiko import SSHClient

-     from airflow.exceptions import AirflowException
+     from airflow.providers.common.compat.sdk import AirflowException


  def identify_os(ssh_client: SSHClient) -> str:
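The two hunks above swap the direct `airflow.exceptions` import for the `airflow.providers.common.compat.sdk` shim, so these modules resolve AirflowException through the compat layer rather than a fixed location. The shim's implementation is not part of this diff; as a rough sketch only, such compatibility modules conventionally try the newer Task SDK location first and fall back to the legacy one (both locations below are assumptions, not code from this package)::

    # Hypothetical sketch of a compat shim; the real module's contents are not shown here.
    try:
        from airflow.sdk.exceptions import AirflowException  # assumed Airflow 3 location
    except ImportError:
        from airflow.exceptions import AirflowException  # legacy Airflow 2.x location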