vastdb 1.3.9__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,294 @@
1
+ import datetime
2
+ import decimal
3
+ import itertools
4
+ import random
5
+ from typing import Any, Union, cast
6
+
7
+ import numpy as np
8
+ import pyarrow as pa
9
+ import pyarrow.compute as pc
10
+ import pytest
11
+
12
+ import vastdb.errors
13
+
14
+ from .util import prepare_data
15
+
16
# Scalar element types accepted inside a fixed-size list ("vector") column.
supported_fixed_list_element_types = [
    pa.uint8(),
    pa.uint16(),
    pa.uint32(),
    pa.uint64(),
    pa.int8(),
    pa.int16(),
    pa.int32(),
    pa.int64(),
    pa.float32(),
    pa.float64(),
    pa.decimal128(10),
    pa.date32(),
    pa.timestamp("s"),
    pa.time32("ms"),
    pa.time64("us"),
]

# All the supported element types are supported as non-nullable.
supported_fixed_list_element_fields = [
    pa.field(name="item", type=element_type, nullable=False)
    for element_type in supported_fixed_list_element_types
]

# Element types that are rejected inside a fixed-size list even when the
# element field is non-nullable (strings, nested/variable-size types, bools, binary).
unsupported_fixed_list_element_types = [
    pa.string(),
    pa.list_(pa.int64()),
    pa.list_(pa.int64(), 1),
    pa.map_(pa.utf8(), pa.float64()),
    pa.struct([("x", pa.int16())]),
    pa.bool_(),
    pa.binary(),
]

unsupported_fixed_list_element_fields = [  # Nullable types are not supported.
    pa.field(name="item", type=element_type, nullable=True)
    for element_type in itertools.chain(
        supported_fixed_list_element_types, unsupported_fixed_list_element_types
    )
] + [  # Not nullable unsupported type are unsupported.
    pa.field(name="item", type=element_type, nullable=False)
    for element_type in unsupported_fixed_list_element_types
]

# Fixed-size list types expected to be rejected with an "unsupported column type"
# error (see test_unsupported_fixed_list_types below).
unsupported_fixed_list_types = (
    [
        pa.list_(element_field, 1)
        for element_field in unsupported_fixed_list_element_fields
    ] +
    # Fixed list with amount of elements exceeding the supported limit.
    [pa.list_(
        pa.field("item", pa.int64(), nullable=False), np.iinfo(np.int32).max
    )]
)

# Fixed-size list types expected to be rejected with an "invalid column type
# parameter" error (see test_invalid_fixed_list_types below).
invalid_fixed_list_types = [
    # Fixed list 0 elements.
    pa.list_(pa.field("item", pa.int64(), nullable=False), 0),
]
75
+
76
+
77
def test_vectors(session, clean_bucket_name):
    """
    Test table with efficient vector type - pa.FixedSizeListArray[not nullable numeric].

    Inserts `num_rows` rows where row i holds the vector [i] * dimension, then
    verifies a full scan and a single-row predicate scan both round-trip.
    """
    dimension = 100
    element_type = pa.float32()
    num_rows = 50

    columns = pa.schema(
        [("id", pa.int64()), ("vec", pa.list_(pa.field(name="item", type=element_type, nullable=False), dimension),)]
    )
    ids = range(num_rows)
    expected = pa.table(
        schema=columns,
        data=[
            ids,
            [[i] * dimension for i in ids],
        ],
    )

    with prepare_data(session, clean_bucket_name, "s", "t", expected) as t:
        assert t.arrow_schema == columns

        # Full scan.
        actual = t.select().read_all()
        assert actual == expected

        # Select by id. Use randrange() so the chosen id is always within
        # [0, num_rows): random.randint(0, num_rows) is inclusive on both ends
        # and could pick num_rows, which matches no inserted row and would make
        # the assertions below fail spuriously.
        select_id = random.randrange(num_rows)
        actual = t.select(predicate=(t["id"] == select_id)).read_all()
        assert actual.to_pydict()["vec"] == [[select_id] * dimension]
        assert actual == expected.filter(pc.field("id") == select_id)
109
+
110
+
111
def convert_scalar_type_pyarrow_to_numpy(arrow_type: pa.DataType):
    """Return the NumPy scalar type corresponding to a scalar Arrow type.

    Builds an empty array of ``arrow_type`` and inspects the dtype NumPy
    assigns to it, delegating the type mapping entirely to PyArrow.
    """
    empty_arrow_array = pa.array([], type=arrow_type)
    numpy_dtype = empty_arrow_array.to_numpy().dtype
    return numpy_dtype.type
113
+
114
+
115
def generate_random_pyarrow_value(
    element: Union[pa.DataType, pa.Field], nulls_prob: float = 0.2
) -> Any:
    """
    Generates a random value compatible with the provided PyArrow type.

    Nested types (list, fixed-size list, struct, map) are handled recursively,
    propagating ``nulls_prob`` to the child fields.

    Args:
        element: The pyarrow field/type to generate values for.
        nulls_prob: Probability of creating nulls.

    Returns:
        A Python/NumPy value suitable for building a pa.array of the given
        type, or None (for nullable fields / NullType).

    Raises:
        TypeError: If ``element`` is neither a DataType nor a Field.
        NotImplementedError: For Arrow types not covered below.
    """
    assert 0 <= nulls_prob <= 1

    # A bare DataType carries no nullability info, so treat it as nullable.
    nullable = True

    # Convert Field to DataType.
    if isinstance(element, pa.DataType):
        pa_type = element
    elif isinstance(element, pa.Field):
        pa_type = element.type
        nullable = element.nullable
    else:
        raise TypeError(
            f"Expected pyarrow.DataType or pyarrow.Field, got {type(element)}"
        )

    # Short-circuit: emit a null with probability nulls_prob for nullable fields.
    if nullable and random.random() < nulls_prob:
        return None

    if pa.types.is_boolean(pa_type):
        return random.choice([True, False])
    if pa.types.is_integer(pa_type):
        # Draw from the full range of the matching NumPy integer type.
        np_type = convert_scalar_type_pyarrow_to_numpy(pa_type)
        iinfo = np.iinfo(np_type)
        return np.random.randint(iinfo.min, iinfo.max, dtype=np_type)
    if pa.types.is_floating(pa_type):
        np_type = convert_scalar_type_pyarrow_to_numpy(pa_type)
        finfo = np.finfo(np_type)
        return np_type(random.uniform(float(finfo.min), float(finfo.max)))
    if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
        return "".join(
            random.choices("abcdefghijklmnopqrstuvwxyz ", k=random.randint(5, 20))
        )
    if pa.types.is_binary(pa_type) or pa.types.is_large_binary(pa_type):
        return random.randbytes(random.randint(5, 20))
    if pa.types.is_timestamp(pa_type):
        # Generate a random timestamp within a range (e.g., last 10 years)
        # NOTE(review): always returns a tz-aware (UTC) datetime, even when the
        # Arrow timestamp type itself carries no timezone — relies on PyArrow's
        # conversion at array-build time; confirm this is intended.
        start_datetime = datetime.datetime(2015, 1, 1, tzinfo=datetime.timezone.utc)
        end_datetime = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)
        random_seconds = random.uniform(
            0, (end_datetime - start_datetime).total_seconds()
        )
        return start_datetime + datetime.timedelta(seconds=random_seconds)
    if pa.types.is_date(pa_type):
        start_date = datetime.date(2000, 1, 1)
        end_date = datetime.date(2025, 1, 1)
        random_days = random.randint(0, (end_date - start_date).days)
        return start_date + datetime.timedelta(days=random_days)
    if pa.types.is_time(pa_type):
        return datetime.time(
            random.randint(0, 23), random.randint(0, 59), random.randint(0, 59)
        )
    if pa.types.is_decimal(pa_type):
        pa_type = cast(pa.Decimal128Type, pa_type)
        # Round to `precision` digits, then quantize down to the type's scale
        # so the resulting Decimal has exactly the expected exponent.
        decimal_value = decimal.Decimal(
            round(random.uniform(-1000.0, 1000.0), pa_type.precision)
        )
        quantize_template = decimal.Decimal("1e-%d" % pa_type.scale)
        return decimal_value.quantize(quantize_template)
    if pa.types.is_null(pa_type):  # Explicit NullType
        return None
    if pa.types.is_list(pa_type) or pa.types.is_fixed_size_list(pa_type):
        # For ListType, recursively generate elements for the value_type
        pa_type = (
            cast(pa.FixedSizeListType, pa_type)
            if pa.types.is_fixed_size_list(pa_type)
            else cast(pa.ListType, pa_type)
        )
        # Fixed-size lists must honor the declared size; variable lists get 0-5.
        list_size = (
            pa_type.list_size
            if pa.types.is_fixed_size_list(pa_type)
            else random.randint(0, 5)
        )
        list_elements = [
            generate_random_pyarrow_value(pa_type.value_field, nulls_prob)
            for _ in range(list_size)
        ]
        return list_elements
    if pa.types.is_struct(pa_type):
        struct_dict = {}
        for field in cast(pa.StructType, pa_type):
            # Recursively generate value for each field in the struct
            struct_dict[field.name] = generate_random_pyarrow_value(field, nulls_prob)
        return struct_dict
    if pa.types.is_map(pa_type):
        num_entries = random.randint(0, 3)  # Random number of map entries
        pa_type = cast(pa.MapType, pa_type)
        # NOTE(review): duplicate (or None) random keys collapse in the dict,
        # so the map may end up with fewer than num_entries entries.
        return {
            generate_random_pyarrow_value(pa_type.key_field, nulls_prob): generate_random_pyarrow_value(
                pa_type.item_field, nulls_prob)
            for _ in range(num_entries)
        }

    raise NotImplementedError(
        f"Generation for PyArrow type {pa_type} not implemented yet."
    )
220
+
221
+
222
@pytest.mark.parametrize("element_field", supported_fixed_list_element_fields)
def test_fixed_list_type_values(session, clean_bucket_name, element_field):
    """Round-trip a table with a fixed-size list column of a supported element type."""
    list_size = random.randint(1, 1000)
    num_rows = random.randint(1, 100)

    vec_type = pa.list_(element_field, list_size)
    schema = pa.schema(
        {"id": pa.int64(), "vec": vec_type, "random_int": pa.int64()})

    # First column is a sequential id; every remaining column is filled with
    # randomly generated values matching its declared field.
    column_data = [list(range(num_rows))]
    for col_name in schema.names[1:]:
        field = schema.field(col_name)
        column_data.append(
            [generate_random_pyarrow_value(field) for _ in range(num_rows)]
        )
    expected = pa.table(schema=schema, data=column_data)

    with prepare_data(session, clean_bucket_name, "s", "t", expected) as table:
        assert table.arrow_schema == schema
        assert table.select().read_all() == expected
241
+
242
+
243
@pytest.mark.parametrize("list_type", unsupported_fixed_list_types)
def test_unsupported_fixed_list_types(session, clean_bucket_name, list_type):
    """Creating a table with an unsupported fixed-size list column must be rejected."""
    empty_table = pa.table(schema=pa.schema({"fixed_list": list_type}), data=[[]])

    expected_errors = (vastdb.errors.BadRequest, vastdb.errors.NotSupported)
    with pytest.raises(expected_errors, match=r'TabularUnsupportedColumnType'):
        with prepare_data(session, clean_bucket_name, "s", "t", empty_table):
            pass
251
+
252
+
253
@pytest.mark.parametrize("list_type", invalid_fixed_list_types)
def test_invalid_fixed_list_types(session, clean_bucket_name, list_type):
    """A fixed-size list with an invalid type parameter (zero elements) must be rejected."""
    empty_table = pa.table(schema=pa.schema({"fixed_list": list_type}), data=[[]])

    with pytest.raises(vastdb.errors.BadRequest, match=r'TabularInvalidColumnTypeParam'):
        with prepare_data(session, clean_bucket_name, "s", "t", empty_table):
            pass
261
+
262
+
263
def test_invalid_values_fixed_list(session, clean_bucket_name):
    """Inserting record batches whose column type mismatches the table schema must fail."""
    dimension = 10
    element_type = pa.float32()

    col_name = "vec"
    schema = pa.schema([(col_name, pa.list_(pa.field(name="item", type=element_type, nullable=False), dimension))])
    empty_table = pa.table(schema=schema, data=[[]])

    def _fixed_list(size, item_nullable=False):
        # Fixed-size list of the test element type with the given parameters.
        return pa.list_(pa.field(name="item", type=element_type, nullable=item_nullable), size)

    with prepare_data(session, clean_bucket_name, "s", "t", empty_table) as table:
        invalid_fields = [
            pa.field(col_name, _fixed_list(dimension - 1)),  # too few elements
            pa.field(col_name, _fixed_list(dimension + 1)),  # too many elements
            pa.field(col_name, _fixed_list(dimension, item_nullable=True)),  # nullable items
            schema.field(0).with_nullable(False),  # non-nullable column
        ]
        for field in invalid_fields:
            # Everything that could be null should be in order to be invalid regarding the values and not just the type.
            rb = pa.record_batch(
                schema=pa.schema([field]),
                data=[[[1] * field.type.list_size]]
            )
            with pytest.raises((vastdb.errors.BadRequest, vastdb.errors.NotFound, vastdb.errors.NotSupported),
                               match=r'(TabularInvalidColumnTypeParam)|(TabularUnsupportedColumnType)|(TabularMismatchColumnType)'):
                table.insert(rb)

        # Amount of elements in fixed list is not equal to the list size is enforced by Arrow.
        with pytest.raises(pa.ArrowInvalid):
            # Insert with empty list.
            pa.record_batch(
                schema=schema,
                data=[[[generate_random_pyarrow_value(element_type, 0) for _ in range(dimension + 1)]]],
            )