aio_sf-0.1.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aio_salesforce/__init__.py +27 -0
- aio_salesforce/api/README.md +107 -0
- aio_salesforce/api/__init__.py +65 -0
- aio_salesforce/api/bulk_v2/__init__.py +21 -0
- aio_salesforce/api/bulk_v2/client.py +200 -0
- aio_salesforce/api/bulk_v2/types.py +71 -0
- aio_salesforce/api/describe/__init__.py +31 -0
- aio_salesforce/api/describe/client.py +94 -0
- aio_salesforce/api/describe/types.py +303 -0
- aio_salesforce/api/query/__init__.py +18 -0
- aio_salesforce/api/query/client.py +216 -0
- aio_salesforce/api/query/types.py +38 -0
- aio_salesforce/api/types.py +303 -0
- aio_salesforce/connection.py +511 -0
- aio_salesforce/exporter/__init__.py +38 -0
- aio_salesforce/exporter/bulk_export.py +397 -0
- aio_salesforce/exporter/parquet_writer.py +296 -0
- aio_salesforce/exporter/parquet_writer.py.backup +326 -0
- aio_sf-0.1.0b1.dist-info/METADATA +198 -0
- aio_sf-0.1.0b1.dist-info/RECORD +22 -0
- aio_sf-0.1.0b1.dist-info/WHEEL +4 -0
- aio_sf-0.1.0b1.dist-info/licenses/LICENSE +21 -0
aio_salesforce/exporter/bulk_export.py
@@ -0,0 +1,397 @@
import logging
from typing import Any, Dict, List, Generator, Optional
import csv
import asyncio

from ..api.describe.types import FieldInfo
from ..connection import SalesforceConnection


class QueryResult:
    """
    A query result that supports len() and acts as an iterator over individual records.
    Can be created from a completed job or resumed from a locator.
    """

    def __init__(
        self,
        sf: SalesforceConnection,
        job_id: str,
        total_records: Optional[int] = None,
        query_locator: Optional[str] = None,
        batch_size: int = 10000,
        api_version: Optional[str] = None,
    ):
        """
        Initialize QueryResult.

        :param sf: Salesforce connection instance
        :param job_id: Salesforce job ID
        :param total_records: Total number of records (None if unknown, e.g., when resuming)
        :param query_locator: Starting locator (None to start from the beginning)
        :param batch_size: Number of records to fetch per batch
        :param api_version: Salesforce API version (defaults to the connection version)
        """
        self._sf = sf
        self._job_id = job_id
        self._total_records = total_records
        self._query_locator = query_locator
        self._batch_size = batch_size
        self._api_version = api_version or sf.version

    def __iter__(self):
        """Return an iterator over individual records (collected eagerly)."""
        # For backward compatibility, all records are collected in a blocking manner.
        # This is not ideal for large datasets, but it maintains the existing API.
        try:
            asyncio.get_running_loop()
            # If we're already in an async context, we can't block.
            raise RuntimeError(
                "Cannot iterate QueryResult synchronously when an async event loop is already running. "
                "Use 'async for record in query_result.aiter()' instead."
            )
        except RuntimeError as e:
            if "Cannot iterate" in str(e):
                raise
            # No event loop is running, so we can create one.
            return asyncio.run(self._collect_all_records())

    async def aiter(self):
        """Async iterator that yields individual records."""
        async for record in self._generate_records():
            yield record

    async def _collect_all_records(self):
        """Collect all records into a list for synchronous iteration."""
        records = []
        async for record in self._generate_records():
            records.append(record)
        return iter(records)

    def __len__(self) -> int:
        """Return the total number of records."""
        if self._total_records is None:
            raise ValueError(
                "Total record count is not available (likely resumed from a locator)"
            )
        return self._total_records

    @property
    def total_records(self) -> Optional[int]:
        """Get the total number of records (None if unknown)."""
        return self._total_records

    def has_total_count(self) -> bool:
        """Check whether the total record count is available."""
        return self._total_records is not None

    @property
    def job_id(self) -> str:
        """Get the job ID."""
        return self._job_id

    def resume_from_locator(self, locator: str) -> "QueryResult":
        """
        Create a new QueryResult starting from the given locator.

        :param locator: The locator to resume from
        :returns: New QueryResult instance
        """
        return QueryResult(
            sf=self._sf,
            job_id=self._job_id,
            total_records=None,  # Unknown when resuming
            query_locator=locator,
            batch_size=self._batch_size,
            api_version=self._api_version,
        )

    def _stream_csv_to_records(
        self, response_text: str
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Stream a CSV response and convert it to record dictionaries.

        :param response_text: CSV response text
        :yields: Individual record dictionaries
        """
        lines = response_text.splitlines()

        # Get the header row first
        if not lines:
            # No data in this batch
            return

        try:
            header_line = lines[0]
            fieldnames = next(csv.reader([header_line]))
        except (IndexError, StopIteration, csv.Error):
            # No data in this batch
            return

        # Process each data row
        for line in lines[1:]:
            if line.strip():  # Skip empty lines
                try:
                    # Parse the CSV row
                    row_values = next(csv.reader([line]))
                    # Convert to a dictionary
                    row = dict(zip(fieldnames, row_values))
                    yield row
                except (csv.Error, StopIteration):
                    logging.warning(f"Error parsing line: {line}")
                    # Skip malformed lines
                    continue

    async def _generate_records(self):
        """Async generator that yields individual records."""
        locator = self._query_locator
        ctn = 0

        try:
            while True:
                # Use the bulk_v2 API to get results
                response_text, next_locator = await self._sf.bulk_v2.get_job_results(
                    job_id=self._job_id,
                    locator=locator,
                    max_records=self._batch_size,
                    api_version=self._api_version,
                )

                for record in self._stream_csv_to_records(response_text):
                    ctn += 1
                    yield record

                # Set up the next locator
                locator = next_locator

                if not locator:
                    break

        except Exception as e:
            raise Exception(
                f"Error processing record {ctn}: {e}. Current query locator: {locator}"
            )
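QueryResult deliberately supports both consumption styles implemented by `__iter__` and `aiter()`. A minimal sketch of each, assuming `result` was obtained from `bulk_query()` defined later in this module and that the queried object includes an `Id` field:

```python
# Inside a running event loop: stream records lazily.
async def count_records(result) -> int:
    total = 0
    async for record in result.aiter():  # each record is a Dict[str, Any] keyed by field name
        total += 1
    return total

# Outside any event loop: plain iteration also works, but it buffers every record in memory first.
# for record in result:
#     print(record["Id"])
```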
async def _wait_for_job_completion(
    sf: SalesforceConnection,
    job_id: str,
    api_version: str,
    poll_interval: int,
    timeout: Optional[int],
) -> int:
    """
    Wait for a Salesforce bulk job to complete and return the total record count.

    :param sf: Salesforce connection instance
    :param job_id: Job ID to monitor
    :param api_version: API version to use
    :param poll_interval: Time in seconds between status checks
    :param timeout: Maximum time to wait (None for no timeout)
    :returns: Total number of records processed
    :raises TimeoutError: If job doesn't complete within timeout
    :raises Exception: If job fails
    """
    # Use the new bulk_v2 API
    job_status = await sf.bulk_v2.wait_for_job_completion(
        job_id=job_id,
        poll_interval=poll_interval,
        timeout=timeout,
        api_version=api_version,
    )
    return job_status.get("numberRecordsProcessed", 0)
async def bulk_query(
    sf: SalesforceConnection,
    soql_query: Optional[str],
    all_rows: bool = False,
    existing_job_id: Optional[str] = None,
    query_locator: Optional[str] = None,
    batch_size: int = 10000,
    api_version: Optional[str] = None,
    poll_interval: int = 5,
    timeout: Optional[int] = None,
) -> QueryResult:
    """
    Execute a Salesforce query via the Bulk API 2.0 and return a QueryResult.

    :param sf: A SalesforceConnection instance containing access_token, instance_url, and version.
    :param soql_query: The SOQL query string to execute.
    :param all_rows: If True, includes deleted and archived records.
    :param existing_job_id: Use an existing job ID to continue processing.
    :param query_locator: Use an existing query locator to continue processing.
    :param batch_size: Number of records to fetch per batch.
    :param api_version: Salesforce API version to use (defaults to connection version).
    :param poll_interval: Time in seconds between job status checks.
    :param timeout: Maximum time in seconds to wait for job completion (None = no timeout).
    :returns: QueryResult that can be iterated over and supports len().
    """
    if not soql_query and not existing_job_id:
        raise ValueError("SOQL query or existing job ID must be provided")

    if query_locator and not existing_job_id:
        raise ValueError("query_locator may only be used with an existing job ID")

    # Use the connection version if no api_version is specified
    effective_api_version = api_version or sf.version

    # Step 1: Create the job (if needed)
    if existing_job_id:
        job_id = existing_job_id
        logging.info(f"Using existing job id: {job_id}")
    elif soql_query:
        # Use the bulk_v2 API to create the job
        job_info = await sf.bulk_v2.create_job(
            soql_query=soql_query,
            all_rows=all_rows,
            api_version=effective_api_version,
        )
        job_id = job_info["id"]
    else:
        raise ValueError("SOQL query or existing job ID must be provided")

    # Step 2: Wait for the job to complete
    total_records = await _wait_for_job_completion(
        sf, job_id, effective_api_version, poll_interval, timeout
    )

    # Step 3: Return a QueryResult that manages its own data fetching
    return QueryResult(
        sf=sf,
        job_id=job_id,
        total_records=total_records,
        query_locator=query_locator,
        batch_size=batch_size,
        api_version=effective_api_version,
    )
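A hedged end-to-end sketch of `bulk_query`. How a SalesforceConnection is constructed lives in connection.py, which is outside this hunk, so the `sf` argument here is assumed to be an already-authenticated connection; the SOQL text and timeout are illustrative.

```python
import asyncio

from aio_salesforce.exporter.bulk_export import bulk_query


async def export_accounts(sf):
    # Creates the Bulk API 2.0 job, waits for it to finish, then streams the results.
    result = await bulk_query(sf, "SELECT Id, Name FROM Account", timeout=600)
    print(f"Job {result.job_id} finished with {len(result)} records")
    async for record in result.aiter():
        ...  # process each record dict here


# asyncio.run(export_accounts(sf))
```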
def resume_from_locator(
    sf: SalesforceConnection,
    job_id: str,
    locator: str,
    batch_size: int = 10000,
    api_version: Optional[str] = None,
) -> QueryResult:
    """
    Resume a bulk query from a locator. Useful when you only have a locator and a job_id.

    :param sf: Salesforce connection instance
    :param job_id: Salesforce job ID
    :param locator: Query locator to resume from
    :param batch_size: Number of records to fetch per batch
    :param api_version: Salesforce API version (defaults to connection version)
    :returns: QueryResult that can be iterated over (len() will raise an error since the total is unknown)
    """
    return QueryResult(
        sf=sf,
        job_id=job_id,
        total_records=None,  # Unknown when resuming
        query_locator=locator,
        batch_size=batch_size,
        api_version=api_version,
    )
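If an export is interrupted, the job ID plus the last locator (the locator is surfaced in the error raised by `_generate_records`) can be fed back in. A sketch with placeholder identifiers and an assumed existing connection `sf`:

```python
from aio_salesforce.exporter.bulk_export import resume_from_locator

# Both values below are placeholders for a real job ID and locator.
resumed = resume_from_locator(sf, job_id="750XXXXXXXXXXXX", locator="MTAwMDA")

# len(resumed) would raise ValueError here: the total count is unknown when resuming.
```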
# Helper to get all fields that can be queried through the Bulk API
async def get_bulk_fields(
    sf: SalesforceConnection, object_type: str, api_version: Optional[str] = None
) -> List[FieldInfo]:
    """Get field metadata for queryable fields of a Salesforce object.

    :param sf: Salesforce connection instance
    :param object_type: Name of the Salesforce object (e.g., 'Account', 'Contact')
    :param api_version: API version to use (defaults to connection version)
    :returns: List of field metadata dictionaries, with compound fields excluded
    """
    # Use the metadata API to get the object description
    describe_data = await sf.metadata.describe_sobject(object_type, api_version)
    fields_metadata = describe_data["fields"]

    # Create a set of all compound field names to exclude
    compound_field_names = {
        field.get("compoundFieldName")
        for field in fields_metadata
        if field.get("compoundFieldName")
    }

    # Exclude compound fields; their component fields are queried individually instead
    queryable_fields = [
        field
        for field in fields_metadata
        if field.get("name") not in compound_field_names
    ]

    return queryable_fields
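A common follow-up is assembling a full-width SOQL statement from this metadata; a sketch, assuming an existing connection `sf` (get_bulk_fields is the function defined just above):

```python
async def build_account_query(sf) -> str:
    fields = await get_bulk_fields(sf, "Account")
    field_names = ", ".join(f["name"] for f in fields)  # FieldInfo entries behave like dicts
    return f"SELECT {field_names} FROM Account"
```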
def write_records_to_csv(
    query_result: QueryResult,
    file_path: str,
    encoding: str = "utf-8",
    delimiter: str = ",",
):
    """
    Write records from a QueryResult to a CSV file.

    :param query_result: QueryResult object yielding individual records
    :param file_path: Path to the output CSV file
    :param encoding: File encoding (default: utf-8)
    :param delimiter: CSV delimiter (default: comma)
    """
    with open(file_path, "w", newline="", encoding=encoding) as csvfile:
        writer = None

        for record in query_result:
            # Initialize writer with fieldnames from first record
            if writer is None:
                fieldnames = record.keys()
                writer = csv.DictWriter(
                    csvfile, fieldnames=fieldnames, delimiter=delimiter
                )
                writer.writeheader()

            writer.writerow(record)
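Because `write_records_to_csv` drives the synchronous iteration path, it has to be called while no event loop is running; whether the connection's underlying HTTP session survives across the two `asyncio.run()` calls depends on how SalesforceConnection is implemented, so treat this as a sketch:

```python
import asyncio

from aio_salesforce.exporter.bulk_export import bulk_query, write_records_to_csv

result = asyncio.run(bulk_query(sf, "SELECT Id, Name FROM Account"))  # `sf` assumed to exist
write_records_to_csv(result, "accounts.csv", delimiter=";")
```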
def batch_records(query_result: QueryResult, batch_size: int = 1000):
    """
    Convert individual records into batches for bulk operations.

    :param query_result: QueryResult object yielding individual records
    :param batch_size: Number of records per batch
    :yields: Lists of records (batches)
    """
    batch = []
    for record in query_result:
        batch.append(record)
        if len(batch) >= batch_size:
            yield batch
            batch = []

    # Yield any remaining records
    if batch:
        yield batch


async def batch_records_async(query_result: QueryResult, batch_size: int = 1000):
    """
    Convert individual records into batches for bulk operations (async version).

    :param query_result: QueryResult object yielding individual records
    :param batch_size: Number of records per batch
    :yields: Lists of records (batches)
    """
    batch = []
    async for record in query_result.aiter():
        batch.append(record)
        if len(batch) >= batch_size:
            yield batch
            batch = []

    # Yield any remaining records
    if batch:
        yield batch
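These two helpers turn the per-record stream into fixed-size chunks; ParquetWriter in parquet_writer.py (next hunk) is built on top of them. A minimal sketch of the async variant, where the downstream sink is hypothetical:

```python
async def push_in_chunks(result):
    # batch_records_async is the generator defined just above.
    async for chunk in batch_records_async(result, batch_size=200):
        # `chunk` is a list of up to 200 record dicts
        await downstream_sink.insert_many(chunk)  # hypothetical sink, not part of this package
```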
aio_salesforce/exporter/parquet_writer.py
@@ -0,0 +1,296 @@
"""
|
|
2
|
+
Parquet writer module for converting Salesforce QueryResult to Parquet format.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import pyarrow as pa
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import pyarrow.parquet as pq
|
|
11
|
+
|
|
12
|
+
from ..api.describe.types import FieldInfo
|
|
13
|
+
|
|
14
|
+
from .bulk_export import QueryResult, batch_records, batch_records_async
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def salesforce_to_arrow_type(sf_type: str) -> pa.DataType:
|
|
18
|
+
"""Convert Salesforce data types to Arrow data types."""
|
|
19
|
+
type_mapping = {
|
|
20
|
+
"string": pa.string(),
|
|
21
|
+
"boolean": pa.bool_(),
|
|
22
|
+
"int": pa.int64(),
|
|
23
|
+
"double": pa.float64(),
|
|
24
|
+
"date": pa.string(), # Store as string since SF returns ISO format
|
|
25
|
+
"datetime": pa.string(), # Store as string since SF returns ISO format
|
|
26
|
+
"currency": pa.float64(),
|
|
27
|
+
"reference": pa.string(),
|
|
28
|
+
"picklist": pa.string(),
|
|
29
|
+
"multipicklist": pa.string(),
|
|
30
|
+
"textarea": pa.string(),
|
|
31
|
+
"phone": pa.string(),
|
|
32
|
+
"url": pa.string(),
|
|
33
|
+
"email": pa.string(),
|
|
34
|
+
"combobox": pa.string(),
|
|
35
|
+
"percent": pa.float64(),
|
|
36
|
+
"id": pa.string(),
|
|
37
|
+
"base64": pa.string(),
|
|
38
|
+
"anyType": pa.string(),
|
|
39
|
+
}
|
|
40
|
+
return type_mapping.get(sf_type.lower(), pa.string())
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def create_schema_from_metadata(fields_metadata: List[FieldInfo]) -> pa.Schema:
|
|
44
|
+
"""
|
|
45
|
+
Create a PyArrow schema from Salesforce field metadata.
|
|
46
|
+
|
|
47
|
+
:param fields_metadata: List of field metadata dictionaries from Salesforce
|
|
48
|
+
:returns: PyArrow schema
|
|
49
|
+
"""
|
|
50
|
+
arrow_fields = []
|
|
51
|
+
for field in fields_metadata:
|
|
52
|
+
field_name = field.get("name", "").lower() # Normalize to lowercase
|
|
53
|
+
sf_type = field.get("type", "string")
|
|
54
|
+
arrow_type = salesforce_to_arrow_type(sf_type)
|
|
55
|
+
# All fields are nullable since Salesforce can return empty values
|
|
56
|
+
arrow_fields.append(pa.field(field_name, arrow_type, nullable=True))
|
|
57
|
+
|
|
58
|
+
return pa.schema(arrow_fields)
|
|
59
|
+
|
|
60
|
+
|
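A small sketch of the metadata-to-schema mapping using a hand-written field list; in practice the list would come from `get_bulk_fields()` in bulk_export.py:

```python
fields_metadata = [
    {"name": "Id", "type": "id"},
    {"name": "AnnualRevenue", "type": "currency"},
    {"name": "IsDeleted", "type": "boolean"},
]
schema = create_schema_from_metadata(fields_metadata)
# Resulting fields: id (string), annualrevenue (double), isdeleted (bool); names are lower-cased.
```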
class ParquetWriter:
    """
    Writer class for converting Salesforce QueryResult to Parquet format.
    Supports streaming writes and optional schema from field metadata.
    """

    def __init__(
        self,
        file_path: str,
        schema: Optional[pa.Schema] = None,
        batch_size: int = 10000,
        convert_empty_to_null: bool = True,
    ):
        """
        Initialize ParquetWriter.

        :param file_path: Path to output parquet file
        :param schema: Optional PyArrow schema. If None, will be inferred from first batch
        :param batch_size: Number of records to process in each batch
        :param convert_empty_to_null: Convert empty strings to null values
        """
        self.file_path = file_path
        self.schema = schema
        self.batch_size = batch_size
        self.convert_empty_to_null = convert_empty_to_null
        self._writer = None
        self._schema_finalized = False

        # Ensure parent directory exists
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)

    def write_query_result(self, query_result: QueryResult) -> None:
        """
        Write all records from a QueryResult to the parquet file.

        :param query_result: QueryResult to write
        """
        try:
            for batch in batch_records(query_result, self.batch_size):
                self._write_batch(batch)
        finally:
            self.close()

    async def write_query_result_async(self, query_result: QueryResult) -> None:
        """
        Write all records from a QueryResult to the parquet file (async version).

        :param query_result: QueryResult to write
        """
        try:
            async for batch in batch_records_async(query_result, self.batch_size):
                self._write_batch(batch)
        finally:
            self.close()

    def _write_batch(self, batch: List[Dict[str, Any]]) -> None:
        """Write a batch of records to the parquet file."""
        if not batch:
            return

        # Convert field names to lowercase for consistency
        converted_batch = []
        for record in batch:
            converted_record = {k.lower(): v for k, v in record.items()}
            converted_batch.append(converted_record)

        # Create DataFrame
        df = pd.DataFrame(converted_batch)

        # If schema not finalized, create it from first batch
        if not self._schema_finalized:
            if self.schema is None:
                self.schema = self._infer_schema_from_dataframe(df)
            else:
                # Filter schema to only include fields that are actually in the data
                self.schema = self._filter_schema_to_data(self.schema, df.columns)
            self._schema_finalized = True

        # Apply data type conversions based on schema
        self._convert_dataframe_types(df)

        # Create Arrow table
        table = pa.Table.from_pandas(df, schema=self.schema)

        # Initialize writer if needed
        if self._writer is None:
            self._writer = pq.ParquetWriter(self.file_path, self.schema)

        # Write the table
        self._writer.write_table(table)

    def _infer_schema_from_dataframe(self, df: pd.DataFrame) -> pa.Schema:
        """Infer schema from the first DataFrame."""
        fields = []
        for col_name, dtype in df.dtypes.items():
            if dtype == "object":
                arrow_type = pa.string()
            elif dtype == "bool":
                arrow_type = pa.bool_()
            elif dtype in ["int64", "int32"]:
                arrow_type = pa.int64()
            elif dtype in ["float64", "float32"]:
                arrow_type = pa.float64()
            else:
                arrow_type = pa.string()

            fields.append(pa.field(col_name, arrow_type, nullable=True))

        return pa.schema(fields)

    def _filter_schema_to_data(
        self, schema: pa.Schema, data_columns: List[str]
    ) -> pa.Schema:
        """Filter schema to only include fields that are present in the data."""
        # Convert data columns to set for faster lookup
        data_columns_set = set(data_columns)

        # Filter schema fields to only those present in data
        filtered_fields = []
        for field in schema:
            if field.name in data_columns_set:
                filtered_fields.append(field)

        if len(filtered_fields) != len(data_columns_set):
            # Log fields that are in data but not in schema (shouldn't happen normally)
            missing_in_schema = data_columns_set - {f.name for f in filtered_fields}
            if missing_in_schema:
                logging.warning(
                    f"Fields in data but not in schema: {missing_in_schema}"
                )

        return pa.schema(filtered_fields)

    def _convert_dataframe_types(self, df: pd.DataFrame) -> None:
        """Convert DataFrame types based on the schema."""
        for field in self.schema:
            field_name = field.name
            if field_name not in df.columns:
                continue

            # Convert empty strings to null if requested
            if self.convert_empty_to_null:
                df[field_name] = df[field_name].replace({"": None})

            # Apply type-specific conversions
            if pa.types.is_boolean(field.type):
                # Convert string 'true'/'false' to boolean
                df[field_name] = (
                    df[field_name]
                    .map({"true": True, "false": False, None: None})
                    .fillna(df[field_name])
                )  # Keep original values for non-string booleans
            elif pa.types.is_integer(field.type):
                df[field_name] = pd.to_numeric(df[field_name], errors="coerce").astype(
                    "Int64"
                )  # Nullable integer
            elif pa.types.is_floating(field.type):
                df[field_name] = pd.to_numeric(df[field_name], errors="coerce")

            # Replace empty strings with None for non-string fields
            if not pa.types.is_string(field.type):
                df[field_name] = df[field_name].replace("", pd.NA)

    def close(self) -> None:
        """Close the parquet writer."""
        if self._writer:
            self._writer.close()
            self._writer = None
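Driving the writer directly with the streaming async path; `sf` and the query are assumptions, and `bulk_query` comes from bulk_export.py above:

```python
from aio_salesforce.exporter.bulk_export import bulk_query
from aio_salesforce.exporter.parquet_writer import ParquetWriter


async def export_contacts(sf):
    result = await bulk_query(sf, "SELECT Id, Name, Email FROM Contact")
    writer = ParquetWriter("contacts.parquet", batch_size=5000)
    # Batches are appended as Parquet row groups; the file is closed when the iteration ends.
    await writer.write_query_result_async(result)
```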
def write_query_to_parquet(
    query_result: QueryResult,
    file_path: str,
    fields_metadata: Optional[List[FieldInfo]] = None,
    schema: Optional[pa.Schema] = None,
    batch_size: int = 10000,
    convert_empty_to_null: bool = True,
) -> None:
    """
    Convenience function to write a QueryResult to a parquet file.

    :param query_result: QueryResult to write
    :param file_path: Path to output parquet file
    :param fields_metadata: Optional Salesforce field metadata for schema creation
    :param schema: Optional pre-created PyArrow schema (takes precedence over fields_metadata)
    :param batch_size: Number of records to process in each batch
    :param convert_empty_to_null: Convert empty strings to null values
    """
    effective_schema = None
    if schema:
        effective_schema = schema
    elif fields_metadata:
        effective_schema = create_schema_from_metadata(fields_metadata)

    writer = ParquetWriter(
        file_path=file_path,
        schema=effective_schema,
        batch_size=batch_size,
        convert_empty_to_null=convert_empty_to_null,
    )

    writer.write_query_result(query_result)
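The synchronous wrapper has the same no-running-event-loop caveat as `write_records_to_csv` in bulk_export.py; a sketch, with `sf` assumed to exist:

```python
import asyncio

from aio_salesforce.exporter.bulk_export import bulk_query
from aio_salesforce.exporter.parquet_writer import write_query_to_parquet

result = asyncio.run(bulk_query(sf, "SELECT Id, Name FROM Account"))
write_query_to_parquet(result, "accounts.parquet")
```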
async def write_query_to_parquet_async(
    query_result: QueryResult,
    file_path: str,
    fields_metadata: Optional[List[FieldInfo]] = None,
    schema: Optional[pa.Schema] = None,
    batch_size: int = 10000,
    convert_empty_to_null: bool = True,
) -> None:
    """
    Convenience function to write a QueryResult to a parquet file (async version).

    :param query_result: QueryResult to write
    :param file_path: Path to output parquet file
    :param fields_metadata: Optional Salesforce field metadata for schema creation
    :param schema: Optional pre-created PyArrow schema (takes precedence over fields_metadata)
    :param batch_size: Number of records to process in each batch
    :param convert_empty_to_null: Convert empty strings to null values
    """
    effective_schema = None
    if schema:
        effective_schema = schema
    elif fields_metadata:
        effective_schema = create_schema_from_metadata(fields_metadata)

    writer = ParquetWriter(
        file_path=file_path,
        schema=effective_schema,
        batch_size=batch_size,
        convert_empty_to_null=convert_empty_to_null,
    )

    await writer.write_query_result_async(query_result)
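Pulling the two modules together, a hedged end-to-end sketch: connection setup is omitted because connection.py is not part of these hunks, and the object name, file path, and timeout are illustrative.

```python
import asyncio

from aio_salesforce.exporter.bulk_export import bulk_query, get_bulk_fields
from aio_salesforce.exporter.parquet_writer import write_query_to_parquet_async


async def dump_accounts(sf):
    fields = await get_bulk_fields(sf, "Account")
    soql = "SELECT " + ", ".join(f["name"] for f in fields) + " FROM Account"
    result = await bulk_query(sf, soql, timeout=1800)
    await write_query_to_parquet_async(
        result,
        "exports/account.parquet",
        fields_metadata=fields,  # typed Arrow schema instead of string-only inference
    )


# asyncio.run(dump_accounts(sf))
```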