ydb-sqlalchemy 0.1.14__tar.gz → 0.1.16__tar.gz

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (36)
  1. {ydb_sqlalchemy-0.1.14/ydb_sqlalchemy.egg-info → ydb_sqlalchemy-0.1.16}/PKG-INFO +1 -1
  2. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/setup.py +1 -1
  3. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/test_core.py +190 -8
  4. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/test_suite.py +0 -1
  5. ydb_sqlalchemy-0.1.16/ydb_sqlalchemy/_version.py +1 -0
  6. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/__init__.py +4 -0
  7. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/compiler/base.py +39 -20
  8. ydb_sqlalchemy-0.1.16/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py +145 -0
  9. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/types.py +91 -2
  10. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16/ydb_sqlalchemy.egg-info}/PKG-INFO +1 -1
  11. ydb_sqlalchemy-0.1.14/ydb_sqlalchemy/_version.py +0 -1
  12. ydb_sqlalchemy-0.1.14/ydb_sqlalchemy/sqlalchemy/test_sqlalchemy.py +0 -37
  13. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/LICENSE +0 -0
  14. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/MANIFEST.in +0 -0
  15. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/README.md +0 -0
  16. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/pyproject.toml +0 -0
  17. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/requirements.txt +0 -0
  18. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/setup.cfg +0 -0
  19. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/__init__.py +0 -0
  20. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/conftest.py +0 -0
  21. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/test_inspect.py +0 -0
  22. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/test/test_orm.py +0 -0
  23. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/__init__.py +0 -0
  24. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py +0 -0
  25. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py +0 -0
  26. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py +0 -0
  27. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/datetime_types.py +0 -0
  28. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py +0 -0
  29. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/dml.py +0 -0
  30. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/json.py +0 -0
  31. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy/sqlalchemy/requirements.py +0 -0
  32. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy.egg-info/SOURCES.txt +0 -0
  33. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy.egg-info/dependency_links.txt +0 -0
  34. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy.egg-info/entry_points.txt +0 -0
  35. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy.egg-info/requires.txt +0 -0
  36. {ydb_sqlalchemy-0.1.14 → ydb_sqlalchemy-0.1.16}/ydb_sqlalchemy.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ydb-sqlalchemy
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: YDB Dialect for SQLAlchemy
5
5
  Home-page: http://github.com/ydb-platform/ydb-sqlalchemy
6
6
  Author: Yandex LLC
@@ -13,7 +13,7 @@ with open("requirements.txt") as f:
13
13
 
14
14
  setuptools.setup(
15
15
  name="ydb-sqlalchemy",
16
- version="0.1.14", # AUTOVERSION
16
+ version="0.1.16", # AUTOVERSION
17
17
  description="YDB Dialect for SQLAlchemy",
18
18
  author="Yandex LLC",
19
19
  author_email="ydb@yandex-team.ru",
@@ -181,12 +181,19 @@ class TestSimpleSelect(TablesTest):
181
181
  rows = connection.execute(stm).fetchall()
182
182
  assert set(rows) == {(1,), (2,), (3,), (4,), (6,), (7,)}
183
183
 
184
+ # LIMIT
185
+ rows = connection.execute(tb.select().order_by(tb.c.id).limit(2)).fetchall()
186
+ assert rows == [
187
+ (1, "some text", Decimal("3.141592653")),
188
+ (2, "test text", Decimal("3.14159265")),
189
+ ]
190
+
184
191
  # LIMIT/OFFSET
185
- # rows = connection.execute(tb.select().order_by(tb.c.id).limit(2)).fetchall()
186
- # assert rows == [
187
- # (1, "some text", Decimal("3.141592653")),
188
- # (2, "test text", Decimal("3.14159265")),
189
- # ]
192
+ rows = connection.execute(tb.select().order_by(tb.c.id).limit(2).offset(1)).fetchall()
193
+ assert rows == [
194
+ (2, "test text", Decimal("3.14159265")),
195
+ (3, "test test", Decimal("3.1415926")),
196
+ ]
190
197
 
191
198
  # ORDER BY ASC
192
199
  rows = connection.execute(sa.select(tb.c.id).order_by(tb.c.id)).fetchall()
@@ -223,11 +230,20 @@ class TestTypes(TablesTest):
223
230
  "test_primitive_types",
224
231
  metadata,
225
232
  Column("int", sa.Integer, primary_key=True),
226
- # Column("bin", sa.BINARY),
233
+ Column("bin", sa.BINARY),
227
234
  Column("str", sa.String),
228
235
  Column("float", sa.Float),
229
236
  Column("bool", sa.Boolean),
230
237
  )
238
+ Table(
239
+ "test_all_binary_types",
240
+ metadata,
241
+ Column("id", sa.Integer, primary_key=True),
242
+ Column("bin", sa.BINARY),
243
+ Column("large_bin", sa.LargeBinary),
244
+ Column("blob", sa.BLOB),
245
+ Column("custom_bin", types.Binary),
246
+ )
231
247
  Table(
232
248
  "test_datetime_types",
233
249
  metadata,
@@ -244,7 +260,7 @@ class TestTypes(TablesTest):
244
260
 
245
261
  statement = sa.insert(table).values(
246
262
  int=42,
247
- # bin=b"abc",
263
+ bin=b"abc",
248
264
  str="Hello World!",
249
265
  float=3.5,
250
266
  bool=True,
@@ -253,7 +269,22 @@ class TestTypes(TablesTest):
253
269
  connection.execute(statement)
254
270
 
255
271
  row = connection.execute(sa.select(table)).fetchone()
256
- assert row == (42, "Hello World!", 3.5, True)
272
+ assert row == (42, b"abc", "Hello World!", 3.5, True)
273
+
274
+ def test_all_binary_types(self, connection):
275
+ table = self.tables.test_all_binary_types
276
+ data = {
277
+ "id": 1,
278
+ "bin": b"binary",
279
+ "large_bin": b"large_binary",
280
+ "blob": b"blob",
281
+ "custom_bin": b"custom_binary",
282
+ }
283
+ statement = sa.insert(table).values(**data)
284
+ connection.execute(statement)
285
+
286
+ row = connection.execute(sa.select(table)).fetchone()
287
+ assert row == (1, b"binary", b"large_binary", b"blob", b"custom_binary")
257
288
 
258
289
  def test_integer_types(self, connection):
259
290
  stmt = sa.select(
@@ -1112,3 +1143,154 @@ class TestTablePathPrefix(TablesTest):
1112
1143
  metadata.reflect(reflection_engine)
1113
1144
 
1114
1145
  assert "nested_dir/table" in metadata.tables
1146
+
1147
+
1148
+ class TestAsTable(TablesTest):
1149
+ __backend__ = True
1150
+
1151
+ @classmethod
1152
+ def define_tables(cls, metadata):
1153
+ Table(
1154
+ "test_as_table",
1155
+ metadata,
1156
+ Column("id", Integer, primary_key=True),
1157
+ Column("val_int", Integer, nullable=True),
1158
+ Column("val_str", String, nullable=True),
1159
+ )
1160
+ Table(
1161
+ "test_as_table_json",
1162
+ metadata,
1163
+ Column("id", Integer, primary_key=True),
1164
+ Column("data", sa.JSON, nullable=True),
1165
+ )
1166
+
1167
+ @pytest.mark.parametrize("list_cls", [types.ListType, sa.ARRAY])
1168
+ def test_upsert_as_table(self, connection, list_cls):
1169
+ table = self.tables.test_as_table
1170
+
1171
+ input_data = [
1172
+ {"id": 1, "val_int": 10, "val_str": "a"},
1173
+ {"id": 2, "val_int": None, "val_str": "b"},
1174
+ {"id": 3, "val_int": 30, "val_str": None},
1175
+ ]
1176
+
1177
+ struct_type = types.StructType(
1178
+ {
1179
+ "id": Integer,
1180
+ "val_int": types.Optional(Integer),
1181
+ "val_str": types.Optional(String),
1182
+ }
1183
+ )
1184
+ list_type = list_cls(struct_type)
1185
+
1186
+ bind_param = sa.bindparam("data", type_=list_type)
1187
+
1188
+ upsert_stm = ydb_sa.upsert(table).from_select(
1189
+ ["id", "val_int", "val_str"],
1190
+ sa.select(
1191
+ sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String)
1192
+ ).select_from(sa.func.AS_TABLE(bind_param)),
1193
+ )
1194
+
1195
+ connection.execute(upsert_stm, {"data": input_data})
1196
+
1197
+ rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall()
1198
+ assert rows == [
1199
+ (1, 10, "a"),
1200
+ (2, None, "b"),
1201
+ (3, 30, None),
1202
+ ]
1203
+
1204
+ @pytest.mark.parametrize("list_cls", [types.ListType, sa.ARRAY])
1205
+ def test_upsert_from_table_json(self, connection, list_cls):
1206
+ table = self.tables.test_as_table_json
1207
+
1208
+ input_data = [
1209
+ {"id": 1, "data": {"a": 1}},
1210
+ {"id": 2, "data": [1, 2, 3]},
1211
+ {"id": 3, "data": None},
1212
+ ]
1213
+
1214
+ struct_type = types.StructType.from_table(table)
1215
+ list_type = list_cls(struct_type)
1216
+
1217
+ bind_param = sa.bindparam("input_data", type_=list_type)
1218
+
1219
+ cols = [sa.column(c.name, type_=c.type) for c in table.columns]
1220
+ upsert_stm = ydb_sa.upsert(table).from_select(
1221
+ [c.name for c in table.columns],
1222
+ sa.select(*cols).select_from(sa.func.AS_TABLE(bind_param)),
1223
+ )
1224
+
1225
+ connection.execute(upsert_stm, {"input_data": input_data})
1226
+
1227
+ rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall()
1228
+
1229
+ assert rows == [
1230
+ (1, {"a": 1}),
1231
+ (2, [1, 2, 3]),
1232
+ (3, None),
1233
+ ]
1234
+
1235
+ @pytest.mark.parametrize("list_cls", [types.ListType, sa.ARRAY])
1236
+ def test_insert_as_table(self, connection, list_cls):
1237
+ table = self.tables.test_as_table
1238
+
1239
+ input_data = [
1240
+ {"id": 4, "val_int": 40, "val_str": "d"},
1241
+ {"id": 5, "val_int": None, "val_str": "e"},
1242
+ ]
1243
+
1244
+ struct_type = types.StructType(
1245
+ {
1246
+ "id": Integer,
1247
+ "val_int": types.Optional(Integer),
1248
+ "val_str": types.Optional(String),
1249
+ }
1250
+ )
1251
+ list_type = list_cls(struct_type)
1252
+
1253
+ bind_param = sa.bindparam("data", type_=list_type)
1254
+
1255
+ insert_stm = sa.insert(table).from_select(
1256
+ ["id", "val_int", "val_str"],
1257
+ sa.select(
1258
+ sa.column("id", type_=Integer), sa.column("val_int", type_=Integer), sa.column("val_str", type_=String)
1259
+ ).select_from(sa.func.AS_TABLE(bind_param)),
1260
+ )
1261
+
1262
+ connection.execute(insert_stm, {"data": input_data})
1263
+
1264
+ rows = connection.execute(sa.select(table).where(table.c.id >= 4).order_by(table.c.id)).fetchall()
1265
+ assert rows == [
1266
+ (4, 40, "d"),
1267
+ (5, None, "e"),
1268
+ ]
1269
+
1270
+ @pytest.mark.parametrize("list_cls", [types.ListType, sa.ARRAY])
1271
+ def test_upsert_from_table_reflection(self, connection, list_cls):
1272
+ table = self.tables.test_as_table
1273
+
1274
+ input_data = [
1275
+ {"id": 1, "val_int": 10, "val_str": "a"},
1276
+ {"id": 2, "val_int": None, "val_str": "b"},
1277
+ ]
1278
+
1279
+ struct_type = types.StructType.from_table(table)
1280
+ list_type = list_cls(struct_type)
1281
+
1282
+ bind_param = sa.bindparam("data", type_=list_type)
1283
+
1284
+ cols = [sa.column(c.name, type_=c.type) for c in table.columns]
1285
+ upsert_stm = ydb_sa.upsert(table).from_select(
1286
+ [c.name for c in table.columns],
1287
+ sa.select(*cols).select_from(sa.func.AS_TABLE(bind_param)),
1288
+ )
1289
+
1290
+ connection.execute(upsert_stm, {"data": input_data})
1291
+
1292
+ rows = connection.execute(sa.select(table).order_by(table.c.id)).fetchall()
1293
+ assert rows == [
1294
+ (1, 10, "a"),
1295
+ (2, None, "b"),
1296
+ ]
@@ -274,7 +274,6 @@ class NumericTest(_NumericTest):
274
274
  pass
275
275
 
276
276
 
277
- @pytest.mark.skip("TODO: see issue #13")
278
277
  class BinaryTest(_BinaryTest):
279
278
  pass
280
279
 
@@ -0,0 +1 @@
1
+ VERSION = "0.1.16"
@@ -157,6 +157,10 @@ class YqlDialect(StrCompileDialect):
157
157
  sa.types.DATETIME: types.YqlDateTime,
158
158
  sa.types.TIMESTAMP: types.YqlTimestamp,
159
159
  sa.types.DECIMAL: types.Decimal,
160
+ sa.types.BINARY: types.Binary,
161
+ sa.types.LargeBinary: types.Binary,
162
+ sa.types.BLOB: types.Binary,
163
+ sa.types.ARRAY: types.ListType,
160
164
  }
161
165
 
162
166
  connection_characteristics = util.immutabledict(
@@ -12,6 +12,7 @@ from sqlalchemy.sql.compiler import (
12
12
  StrSQLTypeCompiler,
13
13
  selectable,
14
14
  )
15
+ from sqlalchemy.sql.type_api import to_instance
15
16
  from typing import (
16
17
  Any,
17
18
  Dict,
@@ -24,6 +25,12 @@ from typing import (
24
25
  Union,
25
26
  )
26
27
 
28
+ try:
29
+ from sqlalchemy.types import _Binary as _BinaryType
30
+ except ImportError:
31
+ # For older sqlalchemy versions
32
+ from sqlalchemy.sql.sqltypes import _Binary as _BinaryType
33
+
27
34
 
28
35
  from .. import types
29
36
 
@@ -37,19 +44,6 @@ else:
37
44
  from sqlalchemy import Cast as _cast
38
45
 
39
46
 
40
- STR_QUOTE_MAP = {
41
- "'": "\\'",
42
- "\\": "\\\\",
43
- "\0": "\\0",
44
- "\b": "\\b",
45
- "\f": "\\f",
46
- "\r": "\\r",
47
- "\n": "\\n",
48
- "\t": "\\t",
49
- "%": "%%",
50
- }
51
-
52
-
53
47
  COMPOUND_KEYWORDS = {
54
48
  selectable.CompoundSelect.UNION: "UNION ALL",
55
49
  selectable.CompoundSelect.UNION_ALL: "UNION ALL",
@@ -60,6 +54,19 @@ COMPOUND_KEYWORDS = {
60
54
  }
61
55
 
62
56
 
57
+ ESCAPE_RULES = [
58
+ ("\\", "\\\\"), # Must be first to avoid double escaping
59
+ ("'", "\\'"),
60
+ ("\0", "\\0"),
61
+ ("\b", "\\b"),
62
+ ("\f", "\\f"),
63
+ ("\r", "\\r"),
64
+ ("\n", "\\n"),
65
+ ("\t", "\\t"),
66
+ ("%", "%%"),
67
+ ]
68
+
69
+
63
70
  class BaseYqlTypeCompiler(StrSQLTypeCompiler):
64
71
  def visit_JSON(self, type_: Union[sa.JSON, types.YqlJSON], **kw):
65
72
  return "JSON"
@@ -152,11 +159,17 @@ class BaseYqlTypeCompiler(StrSQLTypeCompiler):
152
159
  inner = self.process(type_.item_type, **kw)
153
160
  return f"List<{inner}>"
154
161
 
162
+ def visit_optional(self, type_: types.Optional, **kw):
163
+ el = to_instance(type_.element_type)
164
+ inner = self.process(el, **kw)
165
+ return f"Optional<{inner}>"
166
+
155
167
  def visit_struct_type(self, type_: types.StructType, **kw):
156
- text = "Struct<"
157
- for field, field_type in type_.fields_types:
158
- text += f"{field}:{self.process(field_type, **kw)}"
159
- return text + ">"
168
+ rendered_types = []
169
+ for field, field_type in type_.fields_types.items():
170
+ type_str = self.process(field_type, **kw)
171
+ rendered_types.append(f"{field}:{type_str}")
172
+ return f"Struct<{','.join(rendered_types)}>"
160
173
 
161
174
  def get_ydb_type(
162
175
  self, type_: sa.types.TypeEngine, is_optional: bool
@@ -167,6 +180,10 @@ class BaseYqlTypeCompiler(StrSQLTypeCompiler):
167
180
  if isinstance(type_, (sa.Text, sa.String)):
168
181
  ydb_type = ydb.PrimitiveType.Utf8
169
182
 
183
+ elif isinstance(type_, types.Optional):
184
+ inner = to_instance(type_.element_type)
185
+ return self.get_ydb_type(inner, is_optional=True)
186
+
170
187
  # Integers
171
188
  elif isinstance(type_, types.UInt64):
172
189
  ydb_type = ydb.PrimitiveType.Uint64
@@ -216,7 +233,7 @@ class BaseYqlTypeCompiler(StrSQLTypeCompiler):
216
233
  ydb_type = ydb.PrimitiveType.Timestamp
217
234
  elif isinstance(type_, sa.Date):
218
235
  ydb_type = ydb.PrimitiveType.Date
219
- elif isinstance(type_, sa.BINARY):
236
+ elif isinstance(type_, _BinaryType):
220
237
  ydb_type = ydb.PrimitiveType.String
221
238
  elif isinstance(type_, sa.Float):
222
239
  ydb_type = ydb.PrimitiveType.Float
@@ -235,7 +252,8 @@ class BaseYqlTypeCompiler(StrSQLTypeCompiler):
235
252
  elif isinstance(type_, types.StructType):
236
253
  ydb_type = ydb.StructType()
237
254
  for field, field_type in type_.fields_types.items():
238
- ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False))
255
+ inner_type = to_instance(field_type)
256
+ ydb_type.add_member(field, self.get_ydb_type(inner_type, is_optional=False))
239
257
  else:
240
258
  raise NotSupportedError(f"{type_} bind variables not supported")
241
259
 
@@ -275,7 +293,8 @@ class BaseYqlCompiler(StrSQLCompiler):
275
293
 
276
294
  def render_literal_value(self, value, type_):
277
295
  if isinstance(value, str):
278
- value = "".join(STR_QUOTE_MAP.get(x, x) for x in value)
296
+ for pattern, replacement in ESCAPE_RULES:
297
+ value = value.replace(pattern, replacement)
279
298
  return f"'{value}'"
280
299
  return super().render_literal_value(value, type_)
281
300
 
@@ -0,0 +1,145 @@
1
+ from datetime import date
2
+ import sqlalchemy as sa
3
+
4
+ from . import YqlDialect, types
5
+
6
+
7
+ def test_casts():
8
+ dialect = YqlDialect()
9
+ expr = sa.literal_column("1/2")
10
+
11
+ res_exprs = [
12
+ sa.cast(expr, types.UInt32),
13
+ sa.cast(expr, types.UInt64),
14
+ sa.cast(expr, types.UInt8),
15
+ sa.func.String.JoinFromList(
16
+ sa.func.ListMap(sa.func.TOPFREQ(expr, 5), types.Lambda(lambda x: sa.cast(x, sa.Text))),
17
+ ", ",
18
+ ),
19
+ ]
20
+
21
+ strs = [str(res_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) for res_expr in res_exprs]
22
+
23
+ assert strs == [
24
+ "CAST(1/2 AS UInt32)",
25
+ "CAST(1/2 AS UInt64)",
26
+ "CAST(1/2 AS UInt8)",
27
+ "String::JoinFromList(ListMap(TOPFREQ(1/2, 5), ($x) -> { RETURN CAST($x AS UTF8) ;}), ', ')",
28
+ ]
29
+
30
+
31
+ def test_ydb_types():
32
+ dialect = YqlDialect()
33
+
34
+ query = sa.literal(date(1996, 11, 19))
35
+ compiled = query.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
36
+
37
+ assert str(compiled) == "Date('1996-11-19')"
38
+
39
+
40
+ def test_binary_type():
41
+ dialect = YqlDialect()
42
+ expr = sa.literal(b"some bytes")
43
+ compiled = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
44
+ assert str(compiled) == "'some bytes'"
45
+
46
+ expr_binary = sa.cast(expr, sa.BINARY)
47
+ compiled_binary = expr_binary.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
48
+ assert str(compiled_binary) == "CAST('some bytes' AS String)"
49
+
50
+
51
+ def test_all_binary_types():
52
+ dialect = YqlDialect()
53
+ expr = sa.literal(b"some bytes")
54
+
55
+ binary_types = [
56
+ sa.BINARY,
57
+ sa.LargeBinary,
58
+ sa.BLOB,
59
+ types.Binary,
60
+ ]
61
+
62
+ for type_ in binary_types:
63
+ expr_binary = sa.cast(expr, type_)
64
+ compiled_binary = expr_binary.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
65
+ assert str(compiled_binary) == "CAST('some bytes' AS String)"
66
+
67
+
68
+ def test_struct_type_generation():
69
+ dialect = YqlDialect()
70
+ type_compiler = dialect.type_compiler
71
+
72
+ # Test default (non-optional)
73
+ struct_type = types.StructType(
74
+ {
75
+ "id": sa.Integer,
76
+ "val_int": sa.Integer,
77
+ }
78
+ )
79
+ ydb_type = type_compiler.get_ydb_type(struct_type, is_optional=False)
80
+ # Keys are sorted
81
+ assert str(ydb_type) == "Struct<id:Int64,val_int:Int64>"
82
+
83
+ # Test optional
84
+ struct_type_opt = types.StructType(
85
+ {
86
+ "id": sa.Integer,
87
+ "val_int": types.Optional(sa.Integer),
88
+ }
89
+ )
90
+ ydb_type_opt = type_compiler.get_ydb_type(struct_type_opt, is_optional=False)
91
+ assert str(ydb_type_opt) == "Struct<id:Int64,val_int:Int64?>"
92
+
93
+
94
+ def test_types_compilation():
95
+ dialect = YqlDialect()
96
+
97
+ def compile_type(type_):
98
+ return dialect.type_compiler.process(type_)
99
+
100
+ assert compile_type(types.UInt64()) == "UInt64"
101
+ assert compile_type(types.UInt32()) == "UInt32"
102
+ assert compile_type(types.UInt16()) == "UInt16"
103
+ assert compile_type(types.UInt8()) == "UInt8"
104
+
105
+ assert compile_type(types.Int64()) == "Int64"
106
+ assert compile_type(types.Int32()) == "Int32"
107
+ assert compile_type(types.Int16()) == "Int32"
108
+ assert compile_type(types.Int8()) == "Int8"
109
+
110
+ assert compile_type(types.ListType(types.Int64())) == "List<Int64>"
111
+
112
+ struct = types.StructType({"a": types.Int32(), "b": types.ListType(types.Int32())})
113
+ # Ordered by key: a, b
114
+ assert compile_type(struct) == "Struct<a:Int32,b:List<Int32>>"
115
+
116
+
117
+ def test_optional_type_compilation():
118
+ dialect = YqlDialect()
119
+ type_compiler = dialect.type_compiler
120
+
121
+ def compile_type(type_):
122
+ return type_compiler.process(type_)
123
+
124
+ # Test Optional(Integer)
125
+ opt_int = types.Optional(sa.Integer)
126
+ assert compile_type(opt_int) == "Optional<Int64>"
127
+
128
+ # Test Optional(String)
129
+ opt_str = types.Optional(sa.String)
130
+ assert compile_type(opt_str) == "Optional<UTF8>"
131
+
132
+ # Test Nested Optional
133
+ opt_opt_int = types.Optional(types.Optional(sa.Integer))
134
+ assert compile_type(opt_opt_int) == "Optional<Optional<Int64>>"
135
+
136
+ # Test get_ydb_type
137
+ ydb_type = type_compiler.get_ydb_type(opt_int, is_optional=False)
138
+ import ydb
139
+
140
+ assert isinstance(ydb_type, ydb.OptionalType)
141
+ # Int64 corresponds to PrimitiveType.Int64
142
+ # Note: ydb.PrimitiveType.Int64 is an enum member, but ydb_type.item is also an instance/enum?
143
+ # get_ydb_type returns ydb.PrimitiveType.Int64 (enum) wrapped in OptionalType.
144
+ # OptionalType.item is the inner type.
145
+ assert ydb_type.item == ydb.PrimitiveType.Int64
@@ -8,7 +8,7 @@ if sa_version.startswith("2."):
8
8
  else:
9
9
  from sqlalchemy.sql.expression import ColumnElement
10
10
 
11
- from sqlalchemy import ARRAY, exc, types
11
+ from sqlalchemy import ARRAY, exc, Table, types
12
12
  from sqlalchemy.sql import type_api
13
13
 
14
14
  from .datetime_types import YqlDate, YqlDateTime, YqlTimestamp, YqlDate32, YqlTimestamp64, YqlDateTime64 # noqa: F401
@@ -110,18 +110,74 @@ class Decimal(types.DECIMAL):
110
110
  class ListType(ARRAY):
111
111
  __visit_name__ = "list_type"
112
112
 
113
+ def bind_processor(self, dialect):
114
+ item_proc = self.item_type.bind_processor(dialect)
115
+
116
+ def process(value):
117
+ if value is None:
118
+ return None
119
+ return [item_proc(v) if v is not None else None for v in value]
120
+
121
+ if item_proc:
122
+ return process
123
+ return None
124
+
113
125
 
114
126
  class HashableDict(dict):
115
127
  def __hash__(self):
116
128
  return hash(tuple(self.items()))
117
129
 
118
130
 
131
+ class Optional(types.TypeEngine):
132
+ """
133
+ Wrapper for YDB Optional type.
134
+
135
+ Used primarily within StructType to denote nullable fields.
136
+ """
137
+
138
+ __visit_name__ = "optional"
139
+
140
+ def __init__(self, element_type: Union[Type[types.TypeEngine], types.TypeEngine]):
141
+ self.element_type = element_type
142
+
143
+
119
144
  class StructType(types.TypeEngine[Mapping[str, Any]]):
145
+ """
146
+ YDB Struct type.
147
+
148
+ Represents a structured data type with named fields, mapped to a Python dictionary.
149
+ """
150
+
120
151
  __visit_name__ = "struct_type"
121
152
 
122
- def __init__(self, fields_types: Mapping[str, Union[Type[types.TypeEngine], Type[types.TypeDecorator]]]):
153
+ def __init__(
154
+ self,
155
+ fields_types: Mapping[
156
+ str,
157
+ Union[Type[types.TypeEngine], types.TypeEngine, Optional],
158
+ ],
159
+ ):
123
160
  self.fields_types = HashableDict(dict(sorted(fields_types.items())))
124
161
 
162
+ @classmethod
163
+ def from_table(cls, table: Table) -> "StructType":
164
+ """
165
+ Create a StructType definition from a SQLAlchemy Table.
166
+
167
+ Automatically wraps nullable columns in Optional.
168
+
169
+ :param table: SQLAlchemy Table object
170
+ :return: StructType instance
171
+ """
172
+ fields = {}
173
+ for col in table.columns:
174
+ t = col.type
175
+ if col.nullable:
176
+ fields[col.name] = Optional(t)
177
+ else:
178
+ fields[col.name] = t
179
+ return cls(fields)
180
+
125
181
  @property
126
182
  def python_type(self):
127
183
  return dict
@@ -129,6 +185,32 @@ class StructType(types.TypeEngine[Mapping[str, Any]]):
129
185
  def compare_values(self, x, y):
130
186
  return x == y
131
187
 
188
+ def bind_processor(self, dialect):
189
+ processors = {}
190
+ for name, type_ in self.fields_types.items():
191
+ if isinstance(type_, Optional):
192
+ type_ = type_.element_type
193
+
194
+ type_ = type_api.to_instance(type_)
195
+ proc = type_.bind_processor(dialect)
196
+ if proc:
197
+ processors[name] = proc
198
+
199
+ if not processors:
200
+ return None
201
+
202
+ def process(value):
203
+ if value is None:
204
+ return None
205
+ new_value = value.copy()
206
+ for name, proc in processors.items():
207
+ if name in new_value:
208
+ if new_value[name] is not None:
209
+ new_value[name] = proc(new_value[name])
210
+ return new_value
211
+
212
+ return process
213
+
132
214
 
133
215
  class Lambda(ColumnElement):
134
216
  __visit_name__ = "lambda"
@@ -139,3 +221,10 @@ class Lambda(ColumnElement):
139
221
 
140
222
  self.type = type_api.NULLTYPE
141
223
  self.func = func
224
+
225
+
226
+ class Binary(types.LargeBinary):
227
+ __visit_name__ = "BINARY"
228
+
229
+ def bind_processor(self, dialect):
230
+ return None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ydb-sqlalchemy
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: YDB Dialect for SQLAlchemy
5
5
  Home-page: http://github.com/ydb-platform/ydb-sqlalchemy
6
6
  Author: Yandex LLC
@@ -1 +0,0 @@
1
- VERSION = "0.1.14"
@@ -1,37 +0,0 @@
1
- from datetime import date
2
- import sqlalchemy as sa
3
-
4
- from . import YqlDialect, types
5
-
6
-
7
- def test_casts():
8
- dialect = YqlDialect()
9
- expr = sa.literal_column("1/2")
10
-
11
- res_exprs = [
12
- sa.cast(expr, types.UInt32),
13
- sa.cast(expr, types.UInt64),
14
- sa.cast(expr, types.UInt8),
15
- sa.func.String.JoinFromList(
16
- sa.func.ListMap(sa.func.TOPFREQ(expr, 5), types.Lambda(lambda x: sa.cast(x, sa.Text))),
17
- ", ",
18
- ),
19
- ]
20
-
21
- strs = [str(res_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) for res_expr in res_exprs]
22
-
23
- assert strs == [
24
- "CAST(1/2 AS UInt32)",
25
- "CAST(1/2 AS UInt64)",
26
- "CAST(1/2 AS UInt8)",
27
- "String::JoinFromList(ListMap(TOPFREQ(1/2, 5), ($x) -> { RETURN CAST($x AS UTF8) ;}), ', ')",
28
- ]
29
-
30
-
31
- def test_ydb_types():
32
- dialect = YqlDialect()
33
-
34
- query = sa.literal(date(1996, 11, 19))
35
- compiled = query.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
36
-
37
- assert str(compiled) == "Date('1996-11-19')"
File without changes