cloe-nessy 0.3.18__py3-none-any.whl → 0.3.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,16 @@
  import json
- from typing import Any
+ from collections.abc import Generator
+ from datetime import datetime
+ from typing import Any, cast

- import pyspark.sql.functions as F
- from pyspark.sql import DataFrame
+ import pandas as pd
+ from pyspark.sql import types as T
  from requests.auth import AuthBase
+ from typing_extensions import TypedDict

- from cloe_nessy.clients.api_client.api_response import APIResponse
+ from cloe_nessy.session import DataFrame

- from ...clients.api_client import APIClient
+ from ...clients.api_client import APIClient, APIResponse, PaginationConfig, PaginationStrategy, PaginationStrategyType

  from ...clients.api_client.exceptions import (
  APIClientConnectionError,
  APIClientError,
@@ -17,41 +20,314 @@ from ...clients.api_client.exceptions import (
  from .reader import BaseReader


+ class RequestSet(TypedDict):
+ """The format for dynamic requests."""
+
+ endpoint: str
+ params: dict[str, Any]
+ headers: dict[str, Any] | None
+ data: dict[str, Any] | None
+ json_body: dict[str, Any] | None
+
+
+ class MetadataEntry(TypedDict):
+ """An entry for metadata."""
+
+ timestamp: str
+ base_url: str
+ url: str
+ status_code: int
+ reason: str
+ elapsed: float
+ endpoint: str
+ query_parameters: dict[str, str]
+
+
+ class ResponseMetadata(TypedDict):
+ """The metadata response."""
+
+ __metadata: MetadataEntry
+
+
+ class ResponseData(TypedDict):
+ """The response."""
+
+ response: str
+ __metadata: MetadataEntry
+
+
  class APIReader(BaseReader):
- """Utility class for reading an API into a DataFrame.
+ """Utility class for reading an API into a DataFrame with pagination support.

- This class uses an APIClient to fetch data from an API and load it into a Spark DataFrame.
+ This class uses an APIClient to fetch paginated data from an API and load it into a Spark DataFrame.

  Attributes:
  api_client: The client for making API requests.
  """

- def __init__(self, base_url: str, auth: AuthBase | None, default_headers: dict[str, str] | None = None):
+ OUTPUT_SCHEMA = T.StructType(
+ [
+ T.StructField(
+ "json_response",
+ T.ArrayType(
+ T.StructType(
+ [
+ T.StructField("response", T.StringType(), True),
+ T.StructField(
+ "__metadata",
+ T.StructType(
+ [
+ T.StructField("base_url", T.StringType(), True),
+ T.StructField("elapsed", T.DoubleType(), True),
+ T.StructField("reason", T.StringType(), True),
+ T.StructField("status_code", T.LongType(), True),
+ T.StructField("timestamp", T.StringType(), True),
+ T.StructField("url", T.StringType(), True),
+ T.StructField("endpoint", T.StringType(), True),
+ T.StructField(
+ "query_parameters",
+ T.MapType(T.StringType(), T.StringType(), True),
+ True,
+ ),
+ ]
+ ),
+ True,
+ ),
+ ]
+ )
+ ),
+ True,
+ )
+ ]
+ )
+
+ def __init__(
+ self,
+ base_url: str,
+ auth: AuthBase | None = None,
+ default_headers: dict[str, str] | None = None,
+ max_concurrent_requests: int = 8,
+ ):
  """Initializes the APIReader object.

  Args:
- base_url : The base URL for the API.
+ base_url: The base URL for the API.
  auth: The authentication method for the API.
  default_headers: Default headers to include in requests.
+ max_concurrent_requests: The maximum concurrent requests. Defaults to 8.
  """
  super().__init__()
- self.api_client = APIClient(base_url, auth, default_headers)
+ self.base_url = base_url
+ self.auth = auth
+ self.default_headers = default_headers
+ self.max_concurrent_requests = max_concurrent_requests
+
+ @staticmethod
+ def _get_pagination_strategy(config: PaginationConfig | dict[str, str]) -> PaginationStrategy:
+ """Return the appropriate pagination strategy."""
+ if isinstance(config, PaginationConfig):
+ config = config.model_dump() # PaginationStrategy expects a dict
+
+ pagination_strategy: PaginationStrategy = PaginationStrategyType[config["strategy"]].value(config)
+ return pagination_strategy
+
+ @staticmethod
+ def _get_metadata(
+ response: APIResponse, base_url: str, endpoint: str, params: dict[str, Any] | None = None
+ ) -> ResponseMetadata:
+ """Creates a dictionary with metadata from an APIResponse.
+
+ Creates a dictionary containing metadata related to an API response. The metadata includes the current timestamp,
+ the base URL of the API, the URL of the request, the HTTP status code, the reason phrase,
+ and the elapsed time of the request in seconds.
+
+ Args:
+ response: The API response object containing the metadata to be added.
+ base_url: The base url.
+ endpoint: The endpoint.
+ params: The parameters to be passed to the query.
+
+ Returns:
+ The dictionary containing metadata of API response.
+ """
+ params = params or {}
+ metadata: ResponseMetadata = {
+ "__metadata": {
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ "base_url": base_url,
+ "url": response.url,
+ "status_code": response.status_code,
+ "reason": response.reason,
+ "elapsed": response.elapsed.total_seconds(),
+ "endpoint": endpoint,
+ "query_parameters": params.copy(),
+ }
+ }
+ return metadata
+
+ @staticmethod
+ def _paginate(
+ api_client: APIClient,
+ endpoint: str,
+ method: str,
+ key: str | None,
+ params: dict[str, Any],
+ headers: dict[str, Any] | None,
+ data: dict[str, Any] | None,
+ json_body: dict[str, Any] | None,
+ timeout: int,
+ max_retries: int,
+ backoff_factor: int,
+ pagination_config: PaginationConfig,
+ ) -> Generator[ResponseData]:
+ """Paginates through an API endpoint based on the given pagination strategy."""
+ strategy = APIReader._get_pagination_strategy(pagination_config)
+
+ query_parameters = params
+ current_page = 1
+
+ while True:
+ if pagination_config.max_page != -1 and current_page > pagination_config.max_page:
+ break
+
+ response = api_client.request(
+ method=method,
+ endpoint=endpoint,
+ params=query_parameters,
+ headers=headers,
+ data=data,
+ json=json_body,
+ timeout=timeout,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ raise_for_status=False,
+ )
+
+ response_data = {"response": json.dumps(response.to_dict(key))} | APIReader._get_metadata(
+ response, api_client.base_url, endpoint, query_parameters
+ )
+
+ yield cast(ResponseData, response_data)
+
+ if not strategy.has_more_data(response):
+ break
+
+ query_parameters = strategy.get_next_params(query_parameters)
+ current_page += 1
+
+ @staticmethod
+ def _read_from_api(
+ api_client: APIClient,
+ endpoint: str,
+ method: str,
+ key: str | None,
+ timeout: int,
+ params: dict[str, Any],
+ headers: dict[str, Any] | None,
+ data: dict[str, Any] | None,
+ json_body: dict[str, Any] | None,
+ max_retries: int,
+ backoff_factor: int,
+ ) -> list[list[ResponseData]]:
+ try:
+ response = api_client.request(
+ method=method,
+ endpoint=endpoint,
+ timeout=timeout,
+ params=params,
+ headers=headers,
+ data=data,
+ json=json_body,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ )
+ response_data = [
+ [
+ cast(
+ ResponseData,
+ {"response": json.dumps(response.to_dict(key))}
+ | APIReader._get_metadata(response, api_client.base_url, endpoint, params),
+ )
+ ]
+ ]
+ return response_data
+
+ except (APIClientHTTPError, APIClientConnectionError, APIClientTimeoutError) as e:
+ raise RuntimeError(f"API request failed: {e}") from e
+ except APIClientError as e:
+ raise RuntimeError(f"An error occurred while reading the API data: {e}") from e
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e}") from e
+
+ @staticmethod
+ def _read_from_api_with_pagination(
+ api_client: APIClient,
+ endpoint: str,
+ method: str,
+ key: str | None,
+ timeout: int,
+ params: dict[str, Any],
+ headers: dict[str, Any] | None,
+ data: dict[str, Any] | None,
+ json_body: dict[str, Any] | None,
+ pagination_config: PaginationConfig,
+ max_retries: int,
+ backoff_factor: int,
+ ) -> list[list[ResponseData]]:
+ all_data: list[list[ResponseData]] = []
+ all_data_temp: list[ResponseData] = []
+
+ try:
+ for response_data in APIReader._paginate(
+ api_client=api_client,
+ method=method,
+ endpoint=endpoint,
+ key=key,
+ timeout=timeout,
+ params=params,
+ headers=headers,
+ data=data,
+ json_body=json_body,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ pagination_config=pagination_config,
+ ):
+ all_data_temp.append(response_data)
+ if (
+ len(all_data_temp) >= pagination_config.pages_per_array_limit
+ and pagination_config.pages_per_array_limit != -1
+ ):
+ all_data.append(all_data_temp)
+ all_data_temp = []
+
+ if all_data_temp:
+ all_data.append(all_data_temp)
+
+ return all_data
+
+ except (APIClientHTTPError, APIClientConnectionError, APIClientTimeoutError) as e:
+ raise RuntimeError(f"API request failed: {e}") from e
+ except APIClientError as e:
+ raise RuntimeError(f"An error occurred while reading the API data: {e}") from e
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e}") from e

  def read(
  self,
  *,
- endpoint: str = "",
+ endpoint: str | None = None,
  method: str = "GET",
  key: str | None = None,
  timeout: int = 30,
- params: dict[str, str] | None = None,
- headers: dict[str, str] | None = None,
- data: dict[str, str] | None = None,
- json_body: dict[str, str] | None = None,
+ params: dict[str, Any] | None = None,
+ headers: dict[str, Any] | None = None,
+ data: dict[str, Any] | None = None,
+ json_body: dict[str, Any] | None = None,
+ pagination_config: PaginationConfig | None = None,
  max_retries: int = 0,
- options: dict[str, str] | None = None,
- add_metadata_column: bool = False,
- **kwargs: Any,
+ backoff_factor: int = 1,
+ dynamic_requests: list[RequestSet] | None = None,
+ **_: Any,
  ) -> DataFrame:
  """Reads data from an API endpoint and returns it as a DataFrame.

@@ -64,10 +340,11 @@ class APIReader(BaseReader):
  headers: The headers to include in the request.
  data: The form data to include in the request.
  json_body: The JSON data to include in the request.
+ pagination_config: Configuration for pagination.
  max_retries: The maximum number of retries for the request.
+ backoff_factor: Factor for exponential backoff between retries.
  options: Additional options for the createDataFrame function.
- add_metadata_column: If set, adds a __metadata column containing metadata about the API response.
- **kwargs: Additional keyword arguments to maintain compatibility with the base class method.
+ dynamic_requests: A list of request sets (endpoint, params, headers, data, json_body) to execute, one request per entry.

  Returns:
  DataFrame: The Spark DataFrame containing the read data in the json_object column.
@@ -75,69 +352,183 @@ class APIReader(BaseReader):
  Raises:
  RuntimeError: If there is an error with the API request or reading the data.
  """
- if options is None:
- options = {}
- try:
- response = self.api_client.request(
+ api_client = APIClient(
+ base_url=self.base_url,
+ auth=self.auth,
+ default_headers=self.default_headers,
+ pool_maxsize=self.max_concurrent_requests,
+ )
+
+ if dynamic_requests or getattr(pagination_config, "preliminary_probe", False):
+ if not dynamic_requests:
+ if not endpoint:
+ raise ValueError("endpoint parameter must be provided.")
+ dynamic_requests = [
+ {
+ "endpoint": endpoint,
+ "params": params or {},
+ "headers": headers,
+ "data": data,
+ "json_body": json_body,
+ }
+ ]
+
+ return self._read_dynamic(
+ api_client=api_client,
+ dynamic_requests=dynamic_requests,
  method=method,
+ key=key,
+ timeout=timeout,
+ pagination_config=pagination_config,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ )
+
+ params = params if params is not None else {}
+
+ if not endpoint:
+ raise ValueError("endpoint parameter must be provided.")
+
+ if pagination_config is not None:
+ response_data = self._read_from_api_with_pagination(
+ api_client=api_client,
  endpoint=endpoint,
+ method=method,
+ key=key,
  timeout=timeout,
  params=params,
  headers=headers,
  data=data,
- json=json_body,
+ json_body=json_body,
+ pagination_config=pagination_config,
  max_retries=max_retries,
+ backoff_factor=backoff_factor,
  )
- data_list = response.to_dict(key)
- json_string = json.dumps(data_list)
- df: DataFrame = self._spark.createDataFrame(data={json_string}, schema=["json_string"], **options) # type: ignore
- row = df.select("json_string").head()
- if row is not None:
- schema = F.schema_of_json(row[0])
- else:
- raise RuntimeError("It was not possible to infer the schema of the JSON data.")
- df_result = df.withColumn("json_object", F.from_json("json_string", schema)).select("json_object")
- if add_metadata_column:
- df_result = self._add_metadata_column(df_result, response)
- return df_result

- except (APIClientHTTPError, APIClientConnectionError, APIClientTimeoutError) as e:
- raise RuntimeError(f"API request failed: {e}") from e
- except APIClientError as e:
- raise RuntimeError(f"An error occurred while reading the API data: {e}") from e
- except Exception as e:
- raise RuntimeError(f"An unexpected error occurred: {e}") from e
+ else:
+ response_data = self._read_from_api(
+ api_client=api_client,
+ endpoint=endpoint,
+ method=method,
+ key=key,
+ timeout=timeout,
+ params=params,
+ headers=headers,
+ data=data,
+ json_body=json_body,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ )

- def _add_metadata_column(self, df: DataFrame, response: APIResponse):
- """Adds a metadata column to a DataFrame.
+ return self._spark.createDataFrame(data=[(response,) for response in response_data], schema=self.OUTPUT_SCHEMA)

- This method appends a column named `__metadata` to the given DataFrame, containing a map
- of metadata related to an API response. The metadata includes the current timestamp,
- the base URL of the API, the URL of the request, the HTTP status code, the reason phrase,
- and the elapsed time of the request in seconds.
+ def _read_dynamic(
+ self,
+ api_client: APIClient,
+ dynamic_requests: list[RequestSet],
+ method: str,
+ key: str | None,
+ timeout: int,
+ pagination_config: PaginationConfig | None,
+ max_retries: int,
+ backoff_factor: int,
+ ) -> DataFrame:
+ def _process_partition(pdf_iter):
+ for pdf in pdf_iter:
+ for _, row in pdf.iterrows():
+ endpoint = row["endpoint"]
+ params = row["params"] or {}
+ headers = row["headers"] or {}
+ data = row["data"] or {}
+ json_body = row["json_body"] or {}

- Args:
- df: The DataFrame to which the metadata column will be added.
- response: The API response object containing the metadata to be added.
+ if any([pagination_config is None, getattr(pagination_config, "preliminary_probe", False)]):
+ response_data = APIReader._read_from_api(
+ api_client=api_client,
+ endpoint=endpoint,
+ method=method,
+ key=key,
+ timeout=timeout,
+ params=params,
+ headers=headers,
+ data=data,
+ json_body=json_body,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ )
+ else:
+ response_data = APIReader._read_from_api_with_pagination(
+ api_client=api_client,
+ endpoint=endpoint,
+ method=method,
+ key=key,
+ timeout=timeout,
+ params=params,
+ headers=headers,
+ data=data,
+ json_body=json_body,
+ pagination_config=pagination_config,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ )

- Returns:
- DataFrame: The original DataFrame with an added `__metadata` column containing the API response metadata.
- """
- df = df.withColumn(
- "__metadata",
- F.create_map(
- F.lit("timestamp"),
- F.current_timestamp(),
- F.lit("base_url"),
- F.lit(self.api_client.base_url),
- F.lit("url"),
- F.lit(response.url),
- F.lit("status_code"),
- F.lit(response.status_code),
- F.lit("reason"),
- F.lit(response.reason),
- F.lit("elapsed"),
- F.lit(response.elapsed),
- ),
+ yield pd.DataFrame(data=[(response,) for response in response_data])
+
+ if pagination_config is not None and getattr(pagination_config, "preliminary_probe", False):
+ pagination_strategy = APIReader._get_pagination_strategy(pagination_config)
+
+ def make_request(
+ endpoint: str,
+ params: dict[str, Any],
+ headers: dict[str, Any] | None,
+ data: dict[str, Any] | None,
+ json_body: dict[str, Any] | None,
+ ) -> APIResponse:
+ return api_client.request(
+ method=method,
+ endpoint=endpoint,
+ params=params,
+ headers=headers,
+ data=data,
+ json=json_body,
+ timeout=timeout,
+ max_retries=max_retries,
+ backoff_factor=backoff_factor,
+ raise_for_status=False,
+ )
+
+ extended_dynamic_requests: list[RequestSet] = []
+ for request in dynamic_requests:
+ probed_params_items = pagination_strategy.probe_max_page(
+ **request,
+ make_request=make_request,
+ )
+ for probed_params_item in probed_params_items:
+ extended_dynamic_requests.append(
+ {
+ "endpoint": request["endpoint"],
+ "params": probed_params_item,
+ "headers": request["headers"],
+ "data": request["data"],
+ "json_body": request["json_body"],
+ }
+ )
+
+ dynamic_requests = extended_dynamic_requests
+
+ df_requests = self._spark.createDataFrame(
+ cast(dict, dynamic_requests),
+ schema="endpoint string, params map<string, string>, headers map<string, string>, data map<string, string>, json_body map<string, string>",
+ )
+
+ self._console_logger.info(
+ f"Repartitioning requests to achieve [ '{self.max_concurrent_requests}' ] concurrent requests ..."
  )
- return df
+ df_requests = df_requests.repartition(self.max_concurrent_requests)
+ total_requests = df_requests.count()
+
+ self._console_logger.info(f"Preparing to perform [ '{total_requests}' ] API requests in parallel ...")
+
+ df_response = df_requests.mapInPandas(_process_partition, schema=self.OUTPUT_SCHEMA)
+
+ return df_response
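
For orientation, a minimal usage sketch of the pagination-aware reader added above. It is not taken from the package itself: the import paths, the PaginationConfig constructor arguments, and the "PAGE" strategy name are illustrative assumptions; only the APIReader constructor, the read() keyword arguments, and the json_response output column come from the diff.

    from cloe_nessy.integration.reader.api_reader import APIReader  # assumed module path
    from cloe_nessy.clients.api_client import PaginationConfig

    reader = APIReader(base_url="https://api.example.com", max_concurrent_requests=8)

    # Paginated read of one endpoint; strategy, max_page and pages_per_array_limit are assumed
    # to be constructor fields of the pydantic-style PaginationConfig (the diff only shows them being read).
    config = PaginationConfig(strategy="PAGE", max_page=10, pages_per_array_limit=5)
    df = reader.read(endpoint="/v1/items", params={"page_size": "100"}, pagination_config=config)

    # Fan several requests out across executors; each dict follows the RequestSet TypedDict.
    df_many = reader.read(
        dynamic_requests=[
            {"endpoint": "/v1/items", "params": {"region": "eu"}, "headers": None, "data": None, "json_body": None},
            {"endpoint": "/v1/items", "params": {"region": "us"}, "headers": None, "data": None, "json_body": None},
        ],
    )

Each resulting row carries a json_response array of {response, __metadata} structs, as defined by OUTPUT_SCHEMA above.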
@@ -1,11 +1,11 @@
  from typing import Any

- from pyspark.sql import DataFrame
  from pyspark.sql.utils import AnalysisException

  from cloe_nessy.integration.delta_loader.delta_load_options import DeltaLoadOptions
  from cloe_nessy.integration.delta_loader.delta_loader_factory import DeltaLoaderFactory

+ from ...session import DataFrame
  from .exceptions import ReadOperationFailedError
  from .reader import BaseReader

@@ -47,11 +47,13 @@ class CatalogReader(BaseReader):
  if options is None:
  options = {}
  if not table_identifier:
- raise ValueError("table_identifier is required")
+ raise ValueError("table_identifier is required.")
  if not isinstance(table_identifier, str):
- raise ValueError("table_identifier must be a string")
+ raise ValueError("table_identifier must be a string.")
  if len(table_identifier.split(".")) != 3:
- raise ValueError("table_identifier must be in the format 'catalog.schema.table'")
+ raise ValueError("table_identifier must be in the format 'catalog.schema.table'.")
+
+ options = options or {}

  try:
  if delta_load_options:
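
The validation above requires a three-part, Unity-Catalog-style identifier. A minimal call sketch, assuming default construction and the standard read() entry point of the BaseReader subclasses (the import path is an assumption):

    from cloe_nessy.integration.reader.catalog_reader import CatalogReader  # assumed module path

    df = CatalogReader().read(table_identifier="main.sales.orders")  # must be 'catalog.schema.table'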
@@ -3,7 +3,8 @@ from typing import Any

  import pandas as pd
  import pyspark.sql.functions as F
- from pyspark.sql import DataFrame
+
+ from cloe_nessy.session import DataFrame

  from .reader import BaseReader

@@ -27,7 +28,6 @@ class ExcelDataFrameReader(BaseReader):
  def read(
  self,
  location: str,
- *,
  sheet_name: str | int | list = 0,
  header: int | list[int] = 0,
  index_col: int | list[int] | None = None,
@@ -43,7 +43,7 @@ class ExcelDataFrameReader(BaseReader):
  options: dict | None = None,
  load_as_strings: bool = False,
  add_metadata_column: bool = False,
- **kwargs: Any,
+ **_: Any,
  ) -> DataFrame:
  """Reads Excel file on specified location and returns DataFrame.

@@ -1,10 +1,12 @@
  from typing import Any

  import pyspark.sql.functions as F
- from pyspark.sql import DataFrame, DataFrameReader
+ from pyspark.sql import DataFrameReader
  from pyspark.sql.streaming import DataStreamReader
  from pyspark.sql.types import StructType

+ from cloe_nessy.session import DataFrame
+
  from ...file_utilities import get_file_paths
  from ..delta_loader.delta_load_options import DeltaLoadOptions
  from ..delta_loader.delta_loader_factory import DeltaLoaderFactory
@@ -1,7 +1,7 @@
  from abc import ABC, abstractmethod
  from typing import Any

- from pyspark.sql import DataFrame, SparkSession
+ from cloe_nessy.session import DataFrame, SparkSession

  from ...logging.logger_mixin import LoggerMixin
  from ...session import SessionManager
@@ -1,4 +1,4 @@
- from pyspark.sql import DataFrame
+ from cloe_nessy.session import DataFrame

  from pyspark.sql.utils import AnalysisException
  class CatalogWriter:
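
A pattern running through all of the files above is the switch from "from pyspark.sql import DataFrame" to "from cloe_nessy.session import DataFrame" (and SparkSession). The session module itself is not part of this diff; a plausible minimal sketch, assuming it simply re-exports the PySpark types from one central place so session handling can change without touching every reader and writer:

    # Hypothetical sketch of the cloe_nessy.session re-exports; not taken from this diff.
    from pyspark.sql import DataFrame, SparkSession

    __all__ = ["DataFrame", "SparkSession"]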