singlestoredb 1.11.0__cp38-abi3-win32.whl → 1.12.1__cp38-abi3-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of singlestoredb might be problematic. Click here for more details.

@@ -1,4 +1,5 @@
1
1
  #!/usr/bin/env python3
2
+ import dataclasses
2
3
  import datetime
3
4
  import inspect
4
5
  import numbers
@@ -22,6 +23,12 @@ try:
22
23
  except ImportError:
23
24
  has_numpy = False
24
25
 
26
+ try:
27
+ import pydantic
28
+ has_pydantic = True
29
+ except ImportError:
30
+ has_pydantic = False
31
+
25
32
  from . import dtypes as dt
26
33
  from ..mysql.converters import escape_item # type: ignore
27
34
 
@@ -243,6 +250,9 @@ def classify_dtype(dtype: Any) -> str:
243
250
  if isinstance(dtype, list):
244
251
  return '|'.join(classify_dtype(x) for x in dtype)
245
252
 
253
+ if isinstance(dtype, str):
254
+ return sql_to_dtype(dtype)
255
+
246
256
  # Specific types
247
257
  if dtype is None or dtype is type(None): # noqa: E721
248
258
  return 'null'
@@ -253,6 +263,21 @@ def classify_dtype(dtype: Any) -> str:
253
263
  if dtype is bool:
254
264
  return 'bool'
255
265
 
266
+ if dataclasses.is_dataclass(dtype):
267
+ fields = dataclasses.fields(dtype)
268
+ item_dtypes = ','.join(
269
+ f'{classify_dtype(simplify_dtype(x.type))}' for x in fields
270
+ )
271
+ return f'tuple[{item_dtypes}]'
272
+
273
+ if has_pydantic and inspect.isclass(dtype) and issubclass(dtype, pydantic.BaseModel):
274
+ fields = dtype.model_fields.values()
275
+ item_dtypes = ','.join(
276
+ f'{classify_dtype(simplify_dtype(x.annotation))}' # type: ignore
277
+ for x in fields
278
+ )
279
+ return f'tuple[{item_dtypes}]'
280
+
256
281
  if not inspect.isclass(dtype):
257
282
  # Check for compound types
258
283
  origin = typing.get_origin(dtype)
@@ -261,7 +286,7 @@ def classify_dtype(dtype: Any) -> str:
261
286
  if origin is Tuple:
262
287
  args = typing.get_args(dtype)
263
288
  item_dtypes = ','.join(classify_dtype(x) for x in args)
264
- return f'tuple:{item_dtypes}'
289
+ return f'tuple[{item_dtypes}]'
265
290
 
266
291
  # Array types
267
292
  elif issubclass(origin, array_types):
@@ -312,7 +337,10 @@ def classify_dtype(dtype: Any) -> str:
312
337
  if is_int:
313
338
  return int_type_map.get(name, 'int64')
314
339
 
315
- raise TypeError(f'unsupported type annotation: {dtype}')
340
+ raise TypeError(
341
+ f'unsupported type annotation: {dtype}; '
342
+ 'use `args`/`returns` on the @udf/@tvf decotator to specify the data type',
343
+ )
316
344
 
317
345
 
318
346
  def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
@@ -428,6 +456,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
428
456
  args: List[Dict[str, Any]] = []
429
457
  attrs = getattr(func, '_singlestoredb_attrs', {})
430
458
  name = attrs.get('name', name if name else func.__name__)
459
+ function_type = attrs.get('function_type', 'udf')
431
460
  out: Dict[str, Any] = dict(name=name, args=args)
432
461
 
433
462
  arg_names = [x for x in signature.parameters]
@@ -448,6 +477,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
448
477
 
449
478
  args_overrides = attrs.get('args', None)
450
479
  returns_overrides = attrs.get('returns', None)
480
+ output_fields = attrs.get('output_fields', None)
451
481
 
452
482
  spec_diff = set(arg_names).difference(set(annotations.keys()))
453
483
 
@@ -488,7 +518,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
488
518
  arg_type = collapse_dtypes([
489
519
  classify_dtype(x) for x in simplify_dtype(annotations[arg])
490
520
  ])
491
- sql = dtype_to_sql(arg_type)
521
+ sql = dtype_to_sql(arg_type, function_type=function_type)
492
522
  args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i]))
493
523
 
494
524
  if returns_overrides is None \
@@ -498,13 +528,56 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
498
528
  if isinstance(returns_overrides, str):
499
529
  sql = returns_overrides
500
530
  out_type = sql_to_dtype(sql)
531
+ elif isinstance(returns_overrides, list):
532
+ if not output_fields:
533
+ output_fields = [
534
+ string.ascii_letters[i] for i in range(len(returns_overrides))
535
+ ]
536
+ out_type = 'tuple[' + collapse_dtypes([
537
+ classify_dtype(x)
538
+ for x in simplify_dtype(returns_overrides)
539
+ ]).replace('|', ',') + ']'
540
+ sql = dtype_to_sql(
541
+ out_type, function_type=function_type, field_names=output_fields,
542
+ )
543
+ elif dataclasses.is_dataclass(returns_overrides):
544
+ out_type = collapse_dtypes([
545
+ classify_dtype(x)
546
+ for x in simplify_dtype([x.type for x in returns_overrides.fields])
547
+ ])
548
+ sql = dtype_to_sql(
549
+ out_type,
550
+ function_type=function_type,
551
+ field_names=[x.name for x in returns_overrides.fields],
552
+ )
553
+ elif has_pydantic and inspect.isclass(returns_overrides) \
554
+ and issubclass(returns_overrides, pydantic.BaseModel):
555
+ out_type = collapse_dtypes([
556
+ classify_dtype(x)
557
+ for x in simplify_dtype([x for x in returns_overrides.model_fields.values()])
558
+ ])
559
+ sql = dtype_to_sql(
560
+ out_type,
561
+ function_type=function_type,
562
+ field_names=[x for x in returns_overrides.model_fields.keys()],
563
+ )
501
564
  elif returns_overrides is not None and not isinstance(returns_overrides, str):
502
565
  raise TypeError(f'unrecognized type for return value: {returns_overrides}')
503
566
  else:
567
+ if not output_fields:
568
+ if dataclasses.is_dataclass(signature.return_annotation):
569
+ output_fields = [
570
+ x.name for x in dataclasses.fields(signature.return_annotation)
571
+ ]
572
+ elif has_pydantic and inspect.isclass(signature.return_annotation) \
573
+ and issubclass(signature.return_annotation, pydantic.BaseModel):
574
+ output_fields = list(signature.return_annotation.model_fields.keys())
504
575
  out_type = collapse_dtypes([
505
576
  classify_dtype(x) for x in simplify_dtype(signature.return_annotation)
506
577
  ])
507
- sql = dtype_to_sql(out_type)
578
+ sql = dtype_to_sql(
579
+ out_type, function_type=function_type, field_names=output_fields,
580
+ )
508
581
  out['returns'] = dict(dtype=out_type, sql=sql, default=None)
509
582
 
510
583
  copied_keys = ['database', 'environment', 'packages', 'resources', 'replace']
@@ -559,7 +632,12 @@ def sql_to_dtype(sql: str) -> str:
559
632
  return dtype
560
633
 
561
634
 
562
- def dtype_to_sql(dtype: str, default: Any = None) -> str:
635
+ def dtype_to_sql(
636
+ dtype: str,
637
+ default: Any = None,
638
+ field_names: Optional[List[str]] = None,
639
+ function_type: str = 'udf',
640
+ ) -> str:
563
641
  """
564
642
  Convert a collapsed dtype string to a SQL type.
565
643
 
@@ -569,6 +647,8 @@ def dtype_to_sql(dtype: str, default: Any = None) -> str:
569
647
  Simplified data type string
570
648
  default : Any, optional
571
649
  Default value
650
+ field_names : List[str], optional
651
+ Field names for tuple types
572
652
 
573
653
  Returns
574
654
  -------
@@ -592,7 +672,7 @@ def dtype_to_sql(dtype: str, default: Any = None) -> str:
592
672
  if dtype.startswith('array['):
593
673
  _, dtypes = dtype.split('[', 1)
594
674
  dtypes = dtypes[:-1]
595
- item_dtype = dtype_to_sql(dtypes)
675
+ item_dtype = dtype_to_sql(dtypes, function_type=function_type)
596
676
  return f'ARRAY({item_dtype}){nullable}{default_clause}'
597
677
 
598
678
  if dtype.startswith('tuple['):
@@ -600,11 +680,22 @@ def dtype_to_sql(dtype: str, default: Any = None) -> str:
600
680
  dtypes = dtypes[:-1]
601
681
  item_dtypes = []
602
682
  for i, item in enumerate(dtypes.split(',')):
603
- name = string.ascii_letters[i]
683
+ if field_names:
684
+ name = field_names[i]
685
+ else:
686
+ name = string.ascii_letters[i]
604
687
  if '=' in item:
605
688
  name, item = item.split('=', 1)
606
- item_dtypes.append(name + ' ' + dtype_to_sql(item))
607
- return f'RECORD({", ".join(item_dtypes)}){nullable}{default_clause}'
689
+ item_dtypes.append(
690
+ f'`{name}` ' + dtype_to_sql(item, function_type=function_type),
691
+ )
692
+ if function_type == 'udf':
693
+ return f'RECORD({", ".join(item_dtypes)}){nullable}{default_clause}'
694
+ else:
695
+ return re.sub(
696
+ r' NOT NULL\s*$', r'',
697
+ f'TABLE({", ".join(item_dtypes)}){nullable}{default_clause}',
698
+ )
608
699
 
609
700
  return f'{sql_type_map[dtype]}{nullable}{default_clause}'
610
701
 
@@ -43,8 +43,6 @@ class CreateClusterIdentity(SQLHandler):
43
43
 
44
44
  Remarks
45
45
  -------
46
- * ``FROM <table>`` specifies the SingleStore table to export. The same name will
47
- be used for the exported table.
48
46
  * ``CATALOG`` specifies the details of the catalog to connect to.
49
47
  * ``LINK`` specifies the details of the data storage to connect to.
50
48
 
@@ -69,6 +67,8 @@ class CreateClusterIdentity(SQLHandler):
69
67
 
70
68
  """
71
69
 
70
+ _enabled = False
71
+
72
72
  def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
73
73
  # Catalog
74
74
  catalog_config = json.loads(params['catalog'].get('catalog_config', '{}') or '{}')
@@ -110,11 +110,34 @@ class CreateExport(SQLHandler):
110
110
  from_table
111
111
  catalog
112
112
  storage
113
+ [ partition_by ]
114
+ [ order_by ]
115
+ [ properties ]
113
116
  ;
114
117
 
115
118
  # From table
116
119
  from_table = FROM <table>
117
120
 
121
+ # Transforms
122
+ _col_transform = { VOID | IDENTITY | YEAR | MONTH | DAY | HOUR } ( _transform_col )
123
+ _transform_col = <column>
124
+ _arg_transform = { BUCKET | TRUNCATE } ( _transform_col <comma> _transform_arg )
125
+ _transform_arg = <integer>
126
+ transform = { _col_transform | _arg_transform }
127
+
128
+ # Partitions
129
+ partition_by = PARTITION BY partition_key,...
130
+ partition_key = transform
131
+
132
+ # Sort order
133
+ order_by = ORDER BY sort_key,...
134
+ sort_key = transform [ direction ] [ null_order ]
135
+ direction = { ASC | DESC | ASCENDING | DESCENDING }
136
+ null_order = { NULLS_FIRST | NULLS_LAST }
137
+
138
+ # Properties
139
+ properties = PROPERTIES '<json>'
140
+
118
141
  # Catalog
119
142
  catalog = CATALOG [ _catalog_config ] [ _catalog_creds ]
120
143
  _catalog_config = CONFIG '<catalog-config>'
@@ -163,6 +186,8 @@ class CreateExport(SQLHandler):
163
186
 
164
187
  """ # noqa
165
188
 
189
+ _enabled = False
190
+
166
191
  def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
167
192
  # From table
168
193
  if isinstance(params['from_table'], str):
@@ -189,6 +214,32 @@ class CreateExport(SQLHandler):
189
214
  if wsg._manager is None:
190
215
  raise TypeError('no workspace manager is associated with workspace group')
191
216
 
217
+ partition_by = []
218
+ if params['partition_by']:
219
+ for key in params['partition_by']:
220
+ transform = key['partition_key']['transform']['col_transform']
221
+ part = {}
222
+ part['transform'] = transform[0].lower()
223
+ part['name'] = transform[-1]['transform_col']
224
+ partition_by.append(part)
225
+
226
+ order_by = []
227
+ if params['order_by'] and params['order_by']['by']:
228
+ for key in params['order_by']['by']:
229
+ transform = key['transform']['col_transform']
230
+ order = {}
231
+ order['transform'] = transform[0].lower()
232
+ order['name'] = transform[-1]['transform_col']
233
+ order['direction'] = 'ascending'
234
+ order['null_order'] = 'nulls_first'
235
+ if key.get('direction'):
236
+ if 'desc' in key['direction'].lower():
237
+ order['direction'] = 'descending'
238
+ if key.get('null_order'):
239
+ if 'last' in key['null_order'].lower():
240
+ order['null_order'] = 'nulls_last'
241
+ order_by.append(order)
242
+
192
243
  out = ExportService(
193
244
  wsg,
194
245
  from_database,
@@ -196,6 +247,9 @@ class CreateExport(SQLHandler):
196
247
  dict(**catalog_config, **catalog_creds),
197
248
  dict(**storage_config, **storage_creds),
198
249
  columns=None,
250
+ partition_by=partition_by or None,
251
+ order_by=order_by or None,
252
+ properties=json.loads(params['properties']) if params['properties'] else None,
199
253
  ).start()
200
254
 
201
255
  res = FusionSQLResult()
@@ -217,6 +271,8 @@ class ShowExport(SQLHandler):
217
271
 
218
272
  """
219
273
 
274
+ _enabled = False
275
+
220
276
  def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
221
277
  wsg = get_workspace_group({})
222
278
  out = ExportStatus(params['export_id'], wsg)
@@ -209,7 +209,7 @@ class UploadPersonalFileHandler(UploadFileHandler):
209
209
  FROM local_path [ overwrite ];
210
210
 
211
211
  # Path to file
212
- path = '<path>'
212
+ path = '<filename>'
213
213
 
214
214
  # Path to local file
215
215
  local_path = '<local-path>'
@@ -223,7 +223,7 @@ class UploadPersonalFileHandler(UploadFileHandler):
223
223
 
224
224
  Arguments
225
225
  ---------
226
- * ``<path>``: The path in the personal/shared space where the file is uploaded.
226
+ * ``<filename>``: The filename in the personal/shared space where the file is uploaded.
227
227
  * ``<local-path>``: The path to the file to upload in the local
228
228
  directory.
229
229
 
@@ -237,7 +237,7 @@ class UploadPersonalFileHandler(UploadFileHandler):
237
237
  The following command uploads a file to a personal/shared space and overwrites any
238
238
  existing files at the specified path::
239
239
 
240
- UPLOAD PERSONAL FILE TO '/data/stats.csv'
240
+ UPLOAD PERSONAL FILE TO 'stats.csv'
241
241
  FROM '/tmp/user/stats.csv' OVERWRITE;
242
242
 
243
243
  See Also
@@ -259,7 +259,7 @@ class UploadSharedFileHandler(UploadFileHandler):
259
259
  FROM local_path [ overwrite ];
260
260
 
261
261
  # Path to file
262
- path = '<path>'
262
+ path = '<filename>'
263
263
 
264
264
  # Path to local file
265
265
  local_path = '<local-path>'
@@ -273,7 +273,7 @@ class UploadSharedFileHandler(UploadFileHandler):
273
273
 
274
274
  Arguments
275
275
  ---------
276
- * ``<path>``: The path in the personal/shared space where the file is uploaded.
276
+ * ``<filename>``: The filename in the personal/shared space where the file is uploaded.
277
277
  * ``<local-path>``: The path to the file to upload in the local
278
278
  directory.
279
279
 
@@ -287,7 +287,7 @@ class UploadSharedFileHandler(UploadFileHandler):
287
287
  The following command uploads a file to a personal/shared space and overwrites any
288
288
  existing files at the specified path::
289
289
 
290
- UPLOAD SHARED FILE TO '/data/stats.csv'
290
+ UPLOAD SHARED FILE TO 'stats.csv'
291
291
  FROM '/tmp/user/stats.csv' OVERWRITE;
292
292
 
293
293
  See Also
@@ -7,13 +7,11 @@ from typing import Optional
7
7
  from typing import Union
8
8
 
9
9
  from ...exceptions import ManagementError
10
+ from ...management import files as mgmt_files
10
11
  from ...management import manage_workspaces
11
12
  from ...management.files import FilesManager
12
13
  from ...management.files import FileSpace
13
14
  from ...management.files import manage_files
14
- from ...management.files import MODELS_SPACE
15
- from ...management.files import PERSONAL_SPACE
16
- from ...management.files import SHARED_SPACE
17
15
  from ...management.workspace import StarterWorkspace
18
16
  from ...management.workspace import Workspace
19
17
  from ...management.workspace import WorkspaceGroup
@@ -190,6 +188,10 @@ def get_deployment(
190
188
  * params['group']['deployment_id']
191
189
  * params['in_deployment']['deployment_name']
192
190
  * params['in_deployment']['deployment_id']
191
+ * params['in']['in_group']['deployment_name']
192
+ * params['in']['in_group']['deployment_id']
193
+ * params['in']['in_deployment']['deployment_name']
194
+ * params['in']['in_deployment']['deployment_id']
193
195
 
194
196
  Or, from the SINGLESTOREDB_WORKSPACE_GROUP
195
197
  or SINGLESTOREDB_CLUSTER environment variables.
@@ -199,7 +201,9 @@ def get_deployment(
199
201
 
200
202
  deployment_name = params.get('deployment_name') or \
201
203
  (params.get('in_deployment') or {}).get('deployment_name') or \
202
- (params.get('group') or {}).get('deployment_name')
204
+ (params.get('group') or {}).get('deployment_name') or \
205
+ ((params.get('in') or {}).get('in_group') or {}).get('deployment_name') or \
206
+ ((params.get('in') or {}).get('in_deployment') or {}).get('deployment_name')
203
207
  if deployment_name:
204
208
  workspace_groups = [
205
209
  x for x in manager.workspace_groups
@@ -239,7 +243,9 @@ def get_deployment(
239
243
 
240
244
  deployment_id = params.get('deployment_id') or \
241
245
  (params.get('in_deployment') or {}).get('deployment_id') or \
242
- (params.get('group') or {}).get('deployment_id')
246
+ (params.get('group') or {}).get('deployment_id') or \
247
+ ((params.get('in') or {}).get('in_group') or {}).get('deployment_id') or \
248
+ ((params.get('in') or {}).get('in_deployment') or {}).get('deployment_id')
243
249
  if deployment_id:
244
250
  try:
245
251
  return manager.get_workspace_group(deployment_id)
@@ -298,11 +304,11 @@ def get_file_space(params: Dict[str, Any]) -> FileSpace:
298
304
  if file_location:
299
305
  file_location_lower_case = file_location.lower()
300
306
 
301
- if file_location_lower_case == PERSONAL_SPACE:
307
+ if file_location_lower_case == mgmt_files.PERSONAL_SPACE:
302
308
  return manager.personal_space
303
- elif file_location_lower_case == SHARED_SPACE:
309
+ elif file_location_lower_case == mgmt_files.SHARED_SPACE:
304
310
  return manager.shared_space
305
- elif file_location_lower_case == MODELS_SPACE:
311
+ elif file_location_lower_case == mgmt_files.MODELS_SPACE:
306
312
  return manager.models_space
307
313
  else:
308
314
  raise ValueError(f'invalid file location: {file_location}')
@@ -241,7 +241,7 @@ class FusionSQLResult(object):
241
241
  for row in self.rows:
242
242
  found = True
243
243
  for i, liker in likers:
244
- if not liker.match(row[i]):
244
+ if row[i] is None or not liker.match(row[i]):
245
245
  found = False
246
246
  break
247
247
  if found:
@@ -972,6 +972,10 @@ class Connection(connection.Connection):
972
972
 
973
973
  def __init__(self, **kwargs: Any):
974
974
  from .. import __version__ as client_version
975
+
976
+ if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
977
+ client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']
978
+
975
979
  connection.Connection.__init__(self, **kwargs)
976
980
 
977
981
  host = kwargs.get('host', get_option('host'))
@@ -24,6 +24,9 @@ class ExportService(object):
24
24
  catalog_info: Dict[str, Any]
25
25
  storage_info: Dict[str, Any]
26
26
  columns: Optional[List[str]]
27
+ partition_by: Optional[List[Dict[str, str]]]
28
+ order_by: Optional[List[Dict[str, Dict[str, str]]]]
29
+ properties: Optional[Dict[str, Any]]
27
30
 
28
31
  def __init__(
29
32
  self,
@@ -32,7 +35,10 @@ class ExportService(object):
32
35
  table: str,
33
36
  catalog_info: Union[str, Dict[str, Any]],
34
37
  storage_info: Union[str, Dict[str, Any]],
35
- columns: Optional[List[str]],
38
+ columns: Optional[List[str]] = None,
39
+ partition_by: Optional[List[Dict[str, str]]] = None,
40
+ order_by: Optional[List[Dict[str, Dict[str, str]]]] = None,
41
+ properties: Optional[Dict[str, Any]] = None,
36
42
  ):
37
43
  #: Workspace group
38
44
  self.workspace_group = workspace_group
@@ -58,6 +64,10 @@ class ExportService(object):
58
64
  else:
59
65
  self.storage_info = copy.copy(storage_info)
60
66
 
67
+ self.partition_by = partition_by or None
68
+ self.order_by = order_by or None
69
+ self.properties = properties or None
70
+
61
71
  self._manager: Optional[WorkspaceManager] = workspace_group._manager
62
72
 
63
73
  def __str__(self) -> str:
@@ -93,14 +103,27 @@ class ExportService(object):
93
103
  msg='No workspace manager is associated with this object.',
94
104
  )
95
105
 
106
+ partition_spec = None
107
+ if self.partition_by:
108
+ partition_spec = dict(partitions=self.partition_by)
109
+
110
+ sort_order_spec = None
111
+ if self.order_by:
112
+ sort_order_spec = dict(keys=self.order_by)
113
+
96
114
  out = self._manager._post(
97
115
  f'workspaceGroups/{self.workspace_group.id}/egress/startTableEgress',
98
- json=dict(
99
- databaseName=self.database,
100
- tableName=self.table,
101
- storageInfo=self.storage_info,
102
- catalogInfo=self.catalog_info,
103
- ),
116
+ json={
117
+ k: v for k, v in dict(
118
+ databaseName=self.database,
119
+ tableName=self.table,
120
+ storageInfo=self.storage_info,
121
+ catalogInfo=self.catalog_info,
122
+ partitionSpec=partition_spec,
123
+ sortOrderSpec=sort_order_spec,
124
+ properties=self.properties,
125
+ ).items() if v is not None
126
+ },
104
127
  )
105
128
 
106
129
  return ExportStatus(out.json()['egressID'], self.workspace_group)
@@ -10,11 +10,9 @@ import re
10
10
  from abc import ABC
11
11
  from abc import abstractmethod
12
12
  from typing import Any
13
- from typing import BinaryIO
14
13
  from typing import Dict
15
14
  from typing import List
16
15
  from typing import Optional
17
- from typing import TextIO
18
16
  from typing import Union
19
17
 
20
18
  from .. import config
@@ -362,7 +360,7 @@ class FileLocation(ABC):
362
360
  @abstractmethod
363
361
  def upload_file(
364
362
  self,
365
- local_path: Union[PathLike, TextIO, BinaryIO],
363
+ local_path: Union[PathLike, io.IOBase],
366
364
  path: PathLike,
367
365
  *,
368
366
  overwrite: bool = False,
@@ -385,7 +383,7 @@ class FileLocation(ABC):
385
383
  @abstractmethod
386
384
  def _upload(
387
385
  self,
388
- content: Union[str, bytes, TextIO, BinaryIO],
386
+ content: Union[str, bytes, io.IOBase],
389
387
  path: PathLike,
390
388
  *,
391
389
  overwrite: bool = False,
@@ -628,7 +626,7 @@ class FileSpace(FileLocation):
628
626
 
629
627
  def upload_file(
630
628
  self,
631
- local_path: Union[PathLike, TextIO, BinaryIO],
629
+ local_path: Union[PathLike, io.IOBase],
632
630
  path: PathLike,
633
631
  *,
634
632
  overwrite: bool = False,
@@ -646,7 +644,7 @@ class FileSpace(FileLocation):
646
644
  Should the ``path`` be overwritten if it exists already?
647
645
 
648
646
  """
649
- if isinstance(local_path, (TextIO, BinaryIO)):
647
+ if isinstance(local_path, io.IOBase):
650
648
  pass
651
649
  elif not os.path.isfile(local_path):
652
650
  raise IsADirectoryError(f'local path is not a file: {local_path}')
@@ -657,8 +655,9 @@ class FileSpace(FileLocation):
657
655
 
658
656
  self.remove(path)
659
657
 
660
- if isinstance(local_path, (TextIO, BinaryIO)):
658
+ if isinstance(local_path, io.IOBase):
661
659
  return self._upload(local_path, path, overwrite=overwrite)
660
+
662
661
  return self._upload(open(local_path, 'rb'), path, overwrite=overwrite)
663
662
 
664
663
  def upload_folder(
@@ -727,7 +726,7 @@ class FileSpace(FileLocation):
727
726
 
728
727
  def _upload(
729
728
  self,
730
- content: Union[str, bytes, TextIO, BinaryIO],
729
+ content: Union[str, bytes, io.IOBase],
731
730
  path: PathLike,
732
731
  *,
733
732
  overwrite: bool = False,
@@ -10,11 +10,9 @@ import re
10
10
  import time
11
11
  from collections.abc import Mapping
12
12
  from typing import Any
13
- from typing import BinaryIO
14
13
  from typing import Dict
15
14
  from typing import List
16
15
  from typing import Optional
17
- from typing import TextIO
18
16
  from typing import Union
19
17
 
20
18
  from .. import config
@@ -165,7 +163,7 @@ class Stage(FileLocation):
165
163
 
166
164
  def upload_file(
167
165
  self,
168
- local_path: Union[PathLike, TextIO, BinaryIO],
166
+ local_path: Union[PathLike, io.IOBase],
169
167
  stage_path: PathLike,
170
168
  *,
171
169
  overwrite: bool = False,
@@ -183,7 +181,7 @@ class Stage(FileLocation):
183
181
  Should the ``stage_path`` be overwritten if it exists already?
184
182
 
185
183
  """
186
- if isinstance(local_path, (TextIO, BinaryIO)):
184
+ if isinstance(local_path, io.IOBase):
187
185
  pass
188
186
  elif not os.path.isfile(local_path):
189
187
  raise IsADirectoryError(f'local path is not a file: {local_path}')
@@ -194,8 +192,9 @@ class Stage(FileLocation):
194
192
 
195
193
  self.remove(stage_path)
196
194
 
197
- if isinstance(local_path, (TextIO, BinaryIO)):
195
+ if isinstance(local_path, io.IOBase):
198
196
  return self._upload(local_path, stage_path, overwrite=overwrite)
197
+
199
198
  return self._upload(open(local_path, 'rb'), stage_path, overwrite=overwrite)
200
199
 
201
200
  def upload_folder(
@@ -258,7 +257,7 @@ class Stage(FileLocation):
258
257
 
259
258
  def _upload(
260
259
  self,
261
- content: Union[str, bytes, TextIO, BinaryIO],
260
+ content: Union[str, bytes, io.IOBase],
262
261
  stage_path: PathLike,
263
262
  *,
264
263
  overwrite: bool = False,