ygg 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,745 @@
+ import base64
+ import io
+ import time
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Optional, IO, AnyStr, Union
+
+ import pyarrow as pa
+ import pyarrow.csv as pcsv
+ import pyarrow.parquet as pq
+ from pyarrow.dataset import FileFormat, ParquetFileFormat, CsvFileFormat
+
+ from .path_kind import DatabricksPathKind
+ from ...libs.databrickslib import databricks
+ from ...types.cast.pandas_cast import PandasDataFrame
+ from ...types.cast.polars_pandas_cast import PolarsDataFrame
+ from ...types.cast.registry import convert
+
+ if databricks is not None:
+     from databricks.sdk.service.workspace import ImportFormat, ExportFormat
+     from databricks.sdk.errors.platform import (
+         NotFound,
+         ResourceDoesNotExist,
+         BadRequest,
+     )
+
+ if TYPE_CHECKING:
+     from .path import DatabricksPath
+
+
+ __all__ = [
+     "DatabricksIO"
+ ]
+
+
+ class DatabricksIO(ABC, IO):
+
+     def __init__(
+         self,
+         path: "DatabricksPath",
+         mode: str,
+         encoding: Optional[str] = None,
+         compression: Optional[str] = "detect",
+         position: int = 0,
+         buffer: Optional[io.BytesIO] = None,
+     ):
+         super().__init__()
+
+         self.encoding = encoding
+         self.mode = mode
+         self.compression = compression
+
+         self.path = path
+
+         self.buffer = buffer
+         self.position = position
+
+         self._write_flag = False
+
+     def __enter__(self) -> "DatabricksIO":
+         return self.connect(clone=False)
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.close()
+
+     def __del__(self):
+         self.close()
+
+     def __next__(self):
+         line = self.readline()
+         if not line:
+             raise StopIteration
+         return line
+
+     def __iter__(self):
+         return self
+
+     def __hash__(self):
+         return self.path.__hash__()
+
+     @classmethod
+     def create_instance(
+         cls,
+         path: "DatabricksPath",
+         mode: str,
+         encoding: Optional[str] = None,
+         compression: Optional[str] = "detect",
+         position: int = 0,
+         buffer: Optional[io.BytesIO] = None,
+     ) -> "DatabricksIO":
+         if path.kind == DatabricksPathKind.VOLUME:
+             return DatabricksVolumeIO(
+                 path=path,
+                 mode=mode,
+                 encoding=encoding,
+                 compression=compression,
+                 position=position,
+                 buffer=buffer,
+             )
+         elif path.kind == DatabricksPathKind.DBFS:
+             return DatabricksDBFSIO(
+                 path=path,
+                 mode=mode,
+                 encoding=encoding,
+                 compression=compression,
+                 position=position,
+                 buffer=buffer,
+             )
+         elif path.kind == DatabricksPathKind.WORKSPACE:
+             return DatabricksWorkspaceIO(
+                 path=path,
+                 mode=mode,
+                 encoding=encoding,
+                 compression=compression,
+                 position=position,
+                 buffer=buffer,
+             )
+         else:
+             raise ValueError(f"Unsupported DatabricksPath kind: {path.kind}")
+
+     @property
+     def workspace(self):
+         return self.path.workspace
+
+     @property
+     def name(self):
+         return self.path.name
+
+     @property
+     def mode(self):
+         return self._mode
+
+     @mode.setter
+     def mode(self, value: str):
+         self._mode = value
+
+         # Basic text/binary behavior:
+         # - binary -> encoding None
+         # - text -> default utf-8
+         if "b" in self._mode:
+             self.encoding = None
+         else:
+             if self.encoding is None:
+                 self.encoding = "utf-8"
+
+     @property
+     def content_length(self) -> int:
+         return self.path.content_length
+
+     def size(self):
+         return self.content_length
+
+     @content_length.setter
+     def content_length(self, value: int):
+         self.path.content_length = value
+
+     @property
+     def buffer(self):
+         if self._buffer is None:
+             self._buffer = io.BytesIO()
+             self._buffer.seek(self.position, io.SEEK_SET)
+         return self._buffer
+
+     @buffer.setter
+     def buffer(self, value: Optional[io.BytesIO]):
+         self._buffer = value
+
+     def clear_buffer(self):
+         if self._buffer is not None:
+             self._buffer.close()
+             self._buffer = None
+
+     def clone_instance(self, **kwargs):
+         return self.__class__(
+             path=kwargs.get("path", self.path),
+             mode=kwargs.get("mode", self.mode),
+             encoding=kwargs.get("encoding", self.encoding),
+             compression=kwargs.get("compression", self.compression),
+             position=kwargs.get("position", self.position),
+             buffer=kwargs.get("buffer", self._buffer),
+         )
+
+     @property
+     def connected(self):
+         return self.path.connected
+
+     def connect(self, clone: bool = False) -> "DatabricksIO":
+         path = self.path.connect(clone=clone)
+
+         if clone:
+             return self.clone_instance(path=path)
+
+         self.path = path
+         return self
+
+     def close(self):
+         self.flush()
+         if self._buffer is not None:
+             self._buffer.close()
+
+     def fileno(self):
+         return hash(self)
+
+     def isatty(self):
+         return False
+
+     def tell(self):
+         return self.position
+
+     def seekable(self):
+         return True
+
+     def seek(self, offset, whence=0, /):
+         if whence == io.SEEK_SET:
+             new_position = offset
+         elif whence == io.SEEK_CUR:
+             new_position = self.position + offset
+         elif whence == io.SEEK_END:
+             end_position = self.content_length
+             new_position = end_position + offset
+         else:
+             raise ValueError("Invalid value for whence")
+
+         if new_position < 0:
+             raise ValueError("New position is before the start of the file")
+
+         if self._buffer is not None:
+             self._buffer.seek(new_position, io.SEEK_SET)
+
+         self.position = new_position
+         return self.position
+
+     def readable(self):
+         return True
+
+     def getvalue(self):
+         if self._buffer is not None:
+             return self._buffer.getvalue()
+         return self.read_all_bytes()
+
+     def getbuffer(self):
+         return self.buffer
+
+     @abstractmethod
+     def read_byte_range(self, start: int, length: int, allow_not_found: bool = False) -> bytes:
+         pass
+
+     def read_all_bytes(self, use_cache: bool = True, allow_not_found: bool = False) -> bytes:
+         if use_cache and self._buffer is not None:
+             buffer_value = self._buffer.getvalue()
+
+             if len(buffer_value) == self.content_length:
+                 return buffer_value
+
+             self._buffer.close()
+             self._buffer = None
+
+         data = self.read_byte_range(0, self.content_length, allow_not_found=allow_not_found)
+
+         # Keep size accurate even if backend didn't know it
+         self.content_length = len(data)
+
+         if use_cache and self._buffer is None:
+             self._buffer = io.BytesIO(data)
+             self._buffer.seek(self.position, io.SEEK_SET)
+
+         return data
+
+     def read(self, n=-1, use_cache: bool = True):
+         if not self.readable():
+             raise IOError("File not open for reading")
+
+         current_position = self.position
+         all_data = self.read_all_bytes(use_cache=use_cache)
+
+         if n == -1:
+             n = self.content_length - current_position
+
+         data = all_data[current_position:current_position + n]
+         read_length = len(data)
+
+         self.position += read_length
+
+         if self.encoding:
+             return data.decode(self.encoding)
+         return data
+
+     def readline(self, limit=-1, use_cache: bool = True):
+         if not self.readable():
+             raise IOError("File not open for reading")
+
+         if self.encoding:
+             # Text-mode: accumulate characters
+             out_chars = []
+             read_chars = 0
+
+             while limit == -1 or read_chars < limit:
+                 ch = self.read(1, use_cache=use_cache)
+                 if not ch:
+                     break
+                 out_chars.append(ch)
+                 read_chars += 1
+                 if ch == "\n":
+                     break
+
+             return "".join(out_chars)
+
+         # Binary-mode: accumulate bytes
+         line_bytes = bytearray()
+         bytes_read = 0
+
+         while limit == -1 or bytes_read < limit:
+             b = self.read(1, use_cache=use_cache)
+             if not b:
+                 break
+             line_bytes.extend(b)
+             bytes_read += 1
+             if b == b"\n":
+                 break
+
+         return bytes(line_bytes)
+
+     def readlines(self, hint=-1, use_cache: bool = True):
+         if not self.readable():
+             raise IOError("File not open for reading")
+
+         lines = []
+         total = 0
+
+         while True:
+             line = self.readline(use_cache=use_cache)
+             if not line:
+                 break
+             lines.append(line)
+             total += len(line)
+             if hint != -1 and total >= hint:
+                 break
+
+         return lines
+
+     def appendable(self):
+         return "a" in self.mode
+
+     def writable(self):
+         return True
+
+     @abstractmethod
+     def write_all_bytes(self, data: bytes):
+         pass
+
+     def truncate(self, size=None, /):
+         if size is None:
+             size = self.position
+
+         if self._buffer is not None:
+             self._buffer.truncate(size)
+         else:
+             data = b"\x00" * size
+             self.write_all_bytes(data=data)
+
+         self.content_length = size
+         self._write_flag = True
+         return size
+
+     def flush(self):
+         if self._write_flag and self._buffer is not None:
+             self.write_all_bytes(data=self._buffer.getvalue())
+             self._write_flag = False
+
+     def write(self, data: AnyStr) -> int:
+         if not self.writable():
+             raise IOError("File not open for writing")
+
+         if isinstance(data, str):
+             data = data.encode(self.encoding or "utf-8")
+
+         written = self.buffer.write(data)
+
+         self.position += written
+         self.content_length = self.position
+         self._write_flag = True
+
+         return written
+
+     def writelines(self, lines) -> None:
+         for line in lines:
+             if isinstance(line, str):
+                 line = line.encode(self.encoding or "utf-8")
+             elif not isinstance(line, (bytes, bytearray)):
+                 raise TypeError(
+                     "a bytes-like or str object is required, not '{}'".format(type(line).__name__)
+                 )
+
+             data = line + b"\n" if not line.endswith(b"\n") else line
+             self.write(data)
+
+     def get_output_stream(self, *args, **kwargs):
+         return self
+
+     def copy_to(
+         self,
+         dest: Union["DatabricksIO", "DatabricksPath", str]
+     ) -> None:
+         if not isinstance(dest, DatabricksIO):
+             from .path import DatabricksPath
+
+             dest_path = DatabricksPath.parse(dest, workspace=self.workspace)
+
+             with dest_path.open(mode="wb") as d:
+                 return self.copy_to(dest=d)
+
+         dest.write_all_bytes(data=self.read_all_bytes(use_cache=False))
+
+     # ---- format helpers ----
+
+     def _reset_for_write(self):
+         if self._buffer is not None:
+             self._buffer.seek(0, io.SEEK_SET)
+             self._buffer.truncate(0)
+
+         self.position = 0
+         self.content_length = 0
+         self._write_flag = True
+
+     # ---- Data Querying Helpers ----
+
+     def write_table(
+         self,
+         table: Union[pa.Table, pa.RecordBatch, PolarsDataFrame, PandasDataFrame],
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         if isinstance(table, pa.Table):
+             return self.write_arrow_table(table, batch_size=batch_size, **kwargs)
+         elif isinstance(table, pa.RecordBatch):
+             return self.write_arrow_batch(table, batch_size=batch_size, **kwargs)
+         elif isinstance(table, PolarsDataFrame):
+             return self.write_polars(table, batch_size=batch_size, **kwargs)
+         elif isinstance(table, PandasDataFrame):
+             return self.write_pandas(table, batch_size=batch_size, **kwargs)
+         else:
+             raise ValueError(f"Cannot write {type(table)} to {self.path}")
+
+     # ---- Arrow ----
+
+     def read_arrow_table(
+         self,
+         file_format: Optional[FileFormat] = None,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ) -> pa.Table:
+         file_format = self.path.file_format if file_format is None else file_format
+         self.seek(0)
+
+         if isinstance(file_format, ParquetFileFormat):
+             return pq.read_table(self, **kwargs)
+
+         if isinstance(file_format, CsvFileFormat):
+             return pcsv.read_csv(self, parse_options=file_format.parse_options)
+
+         raise ValueError(f"Unsupported file format for Arrow table: {file_format}")
+
+     def write_arrow(
+         self,
+         table: Union[pa.Table, pa.RecordBatch],
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         if not isinstance(table, pa.Table):
+             table = convert(table, pa.Table)
+
+         return self.write_arrow_table(
+             table=table,
+             batch_size=batch_size,
+             **kwargs
+         )
+
+     def write_arrow_table(
+         self,
+         table: pa.Table,
+         file_format: Optional[FileFormat] = None,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         file_format = self.path.file_format if file_format is None else file_format
+         buffer = io.BytesIO()
+
+         if isinstance(file_format, ParquetFileFormat):
+             pq.write_table(table, buffer, **kwargs)
+
+         elif isinstance(file_format, CsvFileFormat):
+             pcsv.write_csv(table, buffer, **kwargs)
+
+         else:
+             raise ValueError(f"Unsupported file format for Arrow table: {file_format}")
+
+         self.write_all_bytes(data=buffer.getvalue())
+
+     def write_arrow_batch(
+         self,
+         batch: pa.RecordBatch,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         table = pa.Table.from_batches([batch])
+         self.write_arrow_table(table, batch_size=batch_size, **kwargs)
+
+     def read_arrow_batches(
+         self,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         return (
+             self
+             .read_arrow_table(batch_size=batch_size, **kwargs)
+             .to_batches(max_chunksize=batch_size)
+         )
+
+     # ---- Pandas ----
+
+     def read_pandas(
+         self,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         return self.read_arrow_table(batch_size=batch_size, **kwargs).to_pandas()
+
+     def write_pandas(
+         self,
+         df,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         self.write_arrow_table(pa.table(df), batch_size=batch_size, **kwargs)
+
+     # ---- Polars ----
+
+     def read_polars(
+         self,
+         file_format: Optional[FileFormat] = None,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         import polars as pl
+
+         file_format = self.path.file_format if file_format is None else file_format
+         self.seek(0)
+
+         if isinstance(file_format, ParquetFileFormat):
+             return pl.read_parquet(self, **kwargs)
+
+         if isinstance(file_format, CsvFileFormat):
+             return pl.read_csv(self, **kwargs)
+
+         raise ValueError(f"Unsupported file format for Polars DataFrame: {file_format}")
+
+     def write_polars(
+         self,
+         df,
+         file_format: Optional[FileFormat] = None,
+         batch_size: Optional[int] = None,
+         **kwargs
+     ):
+         file_format = self.path.file_format if file_format is None else file_format
+         self._reset_for_write()
+
+         if isinstance(file_format, ParquetFileFormat):
+             df.write_parquet(self, **kwargs)
+
+         elif isinstance(file_format, CsvFileFormat):
+             df.write_csv(self, **kwargs)
+
+         else:
+             raise ValueError(f"Unsupported file format for Polars DataFrame: {file_format}")
+
+
+ class DatabricksWorkspaceIO(DatabricksIO):
+
+     def read_byte_range(self, start: int, length: int, allow_not_found: bool = False) -> bytes:
+         if length == 0:
+             return b""
+
+         sdk = self.workspace.sdk()
+         client = sdk.workspace
+         full_path = self.path.workspace_full_path()
+
+         result = client.download(
+             path=full_path,
+             format=ExportFormat.AUTO,
+         )
+
+         if result is None:
+             return b""
+
+         data = result.read()
+
+         end = start + length
+         return data[start:end]
+
+     def write_all_bytes(self, data: bytes):
+         sdk = self.workspace.sdk()
+         workspace_client = sdk.workspace
+         full_path = self.path.workspace_full_path()
+
+         try:
+             workspace_client.upload(
+                 full_path,
+                 data,
+                 format=ImportFormat.AUTO,
+                 overwrite=True
+             )
+         except (NotFound, ResourceDoesNotExist, BadRequest):
+             self.path.parent.make_workspace_dir(parents=True)
+
+             workspace_client.upload(
+                 full_path,
+                 data,
+                 format=ImportFormat.AUTO,
+                 overwrite=True
+             )
+
+         self.path.reset_metadata(
+             is_file=True,
+             is_dir=False,
+             size=len(data),
+             mtime=time.time()
+         )
+
+         return self
+
+
+ class DatabricksVolumeIO(DatabricksIO):
+
+     def read_byte_range(self, start: int, length: int, allow_not_found: bool = False) -> bytes:
+         if length == 0:
+             return b""
+
+         sdk = self.workspace.sdk()
+         client = sdk.files
+         full_path = self.path.files_full_path()
+
+         resp = client.download(full_path)
+         # seek() returns the new offset, not the stream, so seek and read separately
+         contents = resp.contents
+         contents.seek(start, io.SEEK_SET)
+         result = contents.read(length)
+
+         return result
+
+     def write_all_bytes(self, data: bytes):
+         sdk = self.workspace.sdk()
+         client = sdk.files
+         full_path = self.path.files_full_path()
+
+         try:
+             client.upload(
+                 full_path,
+                 io.BytesIO(data),
+                 overwrite=True
+             )
+         except (NotFound, ResourceDoesNotExist, BadRequest):
+             self.path.parent.mkdir(parents=True, exist_ok=True)
+
+             client.upload(
+                 full_path,
+                 io.BytesIO(data),
+                 overwrite=True
+             )
+
+         self.path.reset_metadata(
+             is_file=True,
+             is_dir=False,
+             size=len(data),
+             mtime=time.time()
+         )
+
+         return self
+
+
+ class DatabricksDBFSIO(DatabricksIO):
+
+     def read_byte_range(self, start: int, length: int, allow_not_found: bool = False) -> bytes:
+         if length == 0:
+             return b""
+
+         sdk = self.workspace.sdk()
+         client = sdk.dbfs
+         full_path = self.path.dbfs_full_path()
+
+         read_bytes = bytearray()
+         bytes_to_read = length
+         current_position = start
+
+         while bytes_to_read > 0:
+             chunk_size = min(bytes_to_read, 2 * 1024 * 1024)
+
+             resp = client.read(
+                 path=full_path,
+                 offset=current_position,
+                 length=chunk_size
+             )
+
+             if not resp.data:
+                 break
+
+             # resp.data is base64; decode and move offsets by the *decoded* length
+             resp_data_bytes = base64.b64decode(resp.data)
+
+             read_bytes.extend(resp_data_bytes)
+             bytes_read = len(resp_data_bytes)
+             current_position += bytes_read
+             bytes_to_read -= bytes_read
+
+         return bytes(read_bytes)
+
+     def write_all_bytes(self, data: bytes):
+         sdk = self.workspace.sdk()
+         client = sdk.dbfs
+         full_path = self.path.dbfs_full_path()
+
+         try:
+             with client.open(
+                 path=full_path,
+                 read=False,
+                 write=True,
+                 overwrite=True
+             ) as f:
+                 f.write(data)
+         except (NotFound, ResourceDoesNotExist, BadRequest):
+             self.path.parent.mkdir(parents=True, exist_ok=True)
+
+             with client.open(
+                 path=full_path,
+                 read=False,
+                 write=True,
+                 overwrite=True
+             ) as f:
+                 f.write(data)
+
+         self.path.reset_metadata(
+             is_file=True,
+             is_dir=False,
+             size=len(data),
+             mtime=time.time()
+         )
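
For context, a minimal usage sketch of the classes added in this diff. The workspace handle and the volume path below are hypothetical placeholders; DatabricksPath.parse, path.open, write_table, and read_pandas are the calls defined in this file, and open() is presumed to dispatch to the matching DatabricksIO subclass via create_instance.

import pyarrow as pa

# Hypothetical workspace handle and Unity Catalog volume path; adjust to your environment.
path = DatabricksPath.parse("/Volumes/main/default/raw/example.parquet", workspace=workspace)

table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# Writing: write_table dispatches to write_arrow_table, which serializes to an in-memory
# buffer and pushes it with write_all_bytes; the format comes from path.file_format.
with path.open(mode="wb") as f:
    f.write_table(table)

# Reading: read_pandas goes through read_arrow_table(...).to_pandas().
with path.open(mode="rb") as f:
    df = f.read_pandas()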