cocoindex 0.1.49__cp313-cp313-win_amd64.whl → 0.1.51__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +52 -1
- cocoindex/_engine.cp313-win_amd64.pyd +0 -0
- cocoindex/cli.py +22 -4
- cocoindex/convert.py +41 -1
- cocoindex/functions.py +6 -4
- cocoindex/lib.py +1 -2
- cocoindex/setting.py +10 -6
- cocoindex/tests/test_convert.py +359 -84
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_typing.py +505 -0
- cocoindex/typing.py +92 -17
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/METADATA +1 -1
- cocoindex-0.1.51.dist-info/RECORD +28 -0
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/WHEEL +1 -1
- cocoindex/query.py +0 -115
- cocoindex-0.1.49.dist-info/RECORD +0 -27
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,505 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import datetime
|
3
|
+
import uuid
|
4
|
+
from typing import (
|
5
|
+
Annotated,
|
6
|
+
List,
|
7
|
+
Dict,
|
8
|
+
Literal,
|
9
|
+
Any,
|
10
|
+
get_args,
|
11
|
+
NamedTuple,
|
12
|
+
)
|
13
|
+
from collections.abc import Sequence, Mapping
|
14
|
+
import pytest
|
15
|
+
import numpy as np
|
16
|
+
from numpy.typing import NDArray
|
17
|
+
|
18
|
+
from cocoindex.typing import (
|
19
|
+
analyze_type_info,
|
20
|
+
Vector,
|
21
|
+
VectorInfo,
|
22
|
+
TypeKind,
|
23
|
+
TypeAttr,
|
24
|
+
Float32,
|
25
|
+
Float64,
|
26
|
+
encode_enriched_type,
|
27
|
+
AnalyzedTypeInfo,
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
# Two-field dataclass fixture used by the struct / table tests in this module.
# Documented with comments rather than a docstring so the dataclass-generated
# ``__doc__`` stays exactly as it was.
@dataclasses.dataclass
class SimpleDataclass:
    name: str  # textual field
    value: int  # integer field
|
35
|
+
|
36
|
+
|
37
|
+
# NamedTuple fixture used by the struct tests in this module.
# Documented with comments rather than a docstring so ``__doc__`` is unchanged.
class SimpleNamedTuple(NamedTuple):
    name: str  # textual field
    value: Any  # deliberately untyped payload
|
40
|
+
|
41
|
+
|
42
|
+
def test_ndarray_float32_no_dim() -> None:
    """``NDArray[np.float32]`` should analyze to a Float32 vector without a dimension."""
    analyzed = analyze_type_info(NDArray[np.float32])
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=None),
        elem_type=Float32,
        key_type=None,
        struct_type=None,
        np_number_type=np.float32,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
55
|
+
|
56
|
+
|
57
|
+
def test_vector_float32_no_dim() -> None:
    """``Vector[np.float32]`` should analyze to a Float32 vector without a dimension."""
    analyzed = analyze_type_info(Vector[np.float32])
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=None),
        elem_type=Float32,
        key_type=None,
        struct_type=None,
        np_number_type=np.float32,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
70
|
+
|
71
|
+
|
72
|
+
def test_ndarray_float64_with_dim() -> None:
    """A ``VectorInfo(dim=128)`` annotation on a float64 NDArray should be kept."""
    annotated_array = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
    analyzed = analyze_type_info(annotated_array)
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=128),
        elem_type=Float64,
        key_type=None,
        struct_type=None,
        np_number_type=np.float64,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
85
|
+
|
86
|
+
|
87
|
+
def test_vector_float32_with_dim() -> None:
    """``Vector[np.float32, Literal[384]]`` should carry dimension 384."""
    analyzed = analyze_type_info(Vector[np.float32, Literal[384]])
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=384),
        elem_type=Float32,
        key_type=None,
        struct_type=None,
        np_number_type=np.float32,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
100
|
+
|
101
|
+
|
102
|
+
def test_ndarray_int64_no_dim() -> None:
    """``NDArray[np.int64]`` should analyze to a non-nullable Int64 vector.

    The element type is an ``Annotated`` alias, so it is checked through its
    arguments rather than by comparing against a named alias.
    """
    typ = NDArray[np.int64]
    result = analyze_type_info(typ)
    assert result.kind == "Vector"
    assert result.vector_info == VectorInfo(dim=None)
    assert get_args(result.elem_type) == (int, TypeKind("Int64"))
    # The NumPy scalar type taken from the NDArray dtype argument must be
    # reported too; without this check a regression there would go unnoticed.
    assert result.np_number_type is np.int64
    assert not result.nullable
|
109
|
+
|
110
|
+
|
111
|
+
def test_nullable_ndarray() -> None:
    """``NDArray[np.float32] | None`` should analyze to a nullable Float32 vector."""
    optional_array = NDArray[np.float32] | None
    analyzed = analyze_type_info(optional_array)
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=None),
        elem_type=Float32,
        key_type=None,
        struct_type=None,
        np_number_type=np.float32,
        attrs=None,
        nullable=True,
    )
    assert analyzed == expected
|
124
|
+
|
125
|
+
|
126
|
+
def test_vector_str() -> None:
    """``Vector[str]`` should analyze to an undimensioned vector of ``str``."""
    analyzed = analyze_type_info(Vector[str])
    observed = (analyzed.kind, analyzed.elem_type, analyzed.vector_info)
    assert observed == ("Vector", str, VectorInfo(dim=None))
|
132
|
+
|
133
|
+
|
134
|
+
def test_vector_complex64() -> None:
    """``Vector[np.complex64]`` should keep ``np.complex64`` as the element type."""
    analyzed = analyze_type_info(Vector[np.complex64])
    observed = (analyzed.kind, analyzed.elem_type, analyzed.vector_info)
    assert observed == ("Vector", np.complex64, VectorInfo(dim=None))
|
140
|
+
|
141
|
+
|
142
|
+
def test_non_numpy_vector() -> None:
    """``Vector[float, Literal[3]]`` should analyze to a 3-dimensional float vector."""
    analyzed = analyze_type_info(Vector[float, Literal[3]])
    observed = (analyzed.kind, analyzed.elem_type, analyzed.vector_info)
    assert observed == ("Vector", float, VectorInfo(dim=3))
|
148
|
+
|
149
|
+
|
150
|
+
def test_ndarray_any_dtype() -> None:
    """``NDArray[Any]`` should be rejected with a TypeError about the dtype."""
    # Build the alias outside the ``raises`` block so only the analysis call
    # can satisfy the expectation.
    untyped_array = NDArray[Any]
    expected_message = "NDArray for Vector must use a concrete numpy dtype"
    with pytest.raises(TypeError, match=expected_message):
        analyze_type_info(untyped_array)
|
156
|
+
|
157
|
+
|
158
|
+
def test_list_of_primitives() -> None:
    """``List[str]`` should analyze to an undimensioned vector of ``str``."""
    analyzed = analyze_type_info(List[str])
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=None),
        elem_type=str,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
171
|
+
|
172
|
+
|
173
|
+
def test_list_of_structs() -> None:
    """``List[SimpleDataclass]`` should analyze to an LTable of that dataclass."""
    analyzed = analyze_type_info(List[SimpleDataclass])
    expected = AnalyzedTypeInfo(
        kind="LTable",
        vector_info=None,
        elem_type=SimpleDataclass,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
186
|
+
|
187
|
+
|
188
|
+
def test_sequence_of_int() -> None:
    """``Sequence[int]`` should analyze to an undimensioned vector of ``int``."""
    analyzed = analyze_type_info(Sequence[int])
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=None),
        elem_type=int,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
201
|
+
|
202
|
+
|
203
|
+
def test_list_with_vector_info() -> None:
    """A ``VectorInfo(dim=5)`` annotation on ``List[int]`` should be kept."""
    sized_list = Annotated[List[int], VectorInfo(dim=5)]
    analyzed = analyze_type_info(sized_list)
    expected = AnalyzedTypeInfo(
        kind="Vector",
        vector_info=VectorInfo(dim=5),
        elem_type=int,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
216
|
+
|
217
|
+
|
218
|
+
def test_dict_str_int() -> None:
    """``Dict[str, int]`` should analyze to a KTable whose element is ``(str, int)``."""
    analyzed = analyze_type_info(Dict[str, int])
    expected = AnalyzedTypeInfo(
        kind="KTable",
        vector_info=None,
        elem_type=(str, int),
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
231
|
+
|
232
|
+
|
233
|
+
def test_mapping_str_dataclass() -> None:
    """``Mapping[str, SimpleDataclass]`` should analyze to a KTable of key/struct pairs."""
    analyzed = analyze_type_info(Mapping[str, SimpleDataclass])
    expected = AnalyzedTypeInfo(
        kind="KTable",
        vector_info=None,
        elem_type=(str, SimpleDataclass),
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
246
|
+
|
247
|
+
|
248
|
+
def test_dataclass() -> None:
    """A dataclass should analyze to a Struct with ``struct_type`` set to the class."""
    analyzed = analyze_type_info(SimpleDataclass)
    expected = AnalyzedTypeInfo(
        kind="Struct",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=SimpleDataclass,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
261
|
+
|
262
|
+
|
263
|
+
def test_named_tuple() -> None:
    """A NamedTuple should analyze to a Struct with ``struct_type`` set to the class."""
    analyzed = analyze_type_info(SimpleNamedTuple)
    expected = AnalyzedTypeInfo(
        kind="Struct",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=SimpleNamedTuple,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
276
|
+
|
277
|
+
|
278
|
+
def test_tuple_key_value() -> None:
    """The pair ``(str, int)`` should analyze to kind Int64 with ``key_type`` of ``str``."""
    key_value_pair = (str, int)
    analyzed = analyze_type_info(key_value_pair)
    expected = AnalyzedTypeInfo(
        kind="Int64",
        vector_info=None,
        elem_type=None,
        key_type=str,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
291
|
+
|
292
|
+
|
293
|
+
def test_str() -> None:
    """``str`` should analyze to kind Str with every optional field unset."""
    analyzed = analyze_type_info(str)
    expected = AnalyzedTypeInfo(
        kind="Str",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
306
|
+
|
307
|
+
|
308
|
+
def test_bool() -> None:
    """``bool`` should analyze to kind Bool with every optional field unset."""
    analyzed = analyze_type_info(bool)
    expected = AnalyzedTypeInfo(
        kind="Bool",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
321
|
+
|
322
|
+
|
323
|
+
def test_bytes() -> None:
    """``bytes`` should analyze to kind Bytes with every optional field unset."""
    analyzed = analyze_type_info(bytes)
    expected = AnalyzedTypeInfo(
        kind="Bytes",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
336
|
+
|
337
|
+
|
338
|
+
def test_uuid() -> None:
    """``uuid.UUID`` should analyze to kind Uuid with every optional field unset."""
    analyzed = analyze_type_info(uuid.UUID)
    expected = AnalyzedTypeInfo(
        kind="Uuid",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
351
|
+
|
352
|
+
|
353
|
+
def test_date() -> None:
    """``datetime.date`` should analyze to kind Date with every optional field unset."""
    analyzed = analyze_type_info(datetime.date)
    expected = AnalyzedTypeInfo(
        kind="Date",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
366
|
+
|
367
|
+
|
368
|
+
def test_time() -> None:
    """``datetime.time`` should analyze to kind Time with every optional field unset."""
    analyzed = analyze_type_info(datetime.time)
    expected = AnalyzedTypeInfo(
        kind="Time",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
381
|
+
|
382
|
+
|
383
|
+
def test_timedelta() -> None:
    """``datetime.timedelta`` should analyze to kind TimeDelta, other fields unset."""
    analyzed = analyze_type_info(datetime.timedelta)
    expected = AnalyzedTypeInfo(
        kind="TimeDelta",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
396
|
+
|
397
|
+
|
398
|
+
def test_float() -> None:
    """Plain ``float`` should analyze to kind Float64 with no NumPy number type."""
    analyzed = analyze_type_info(float)
    expected = AnalyzedTypeInfo(
        kind="Float64",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
411
|
+
|
412
|
+
|
413
|
+
def test_int() -> None:
    """Plain ``int`` should analyze to kind Int64 with no NumPy number type."""
    analyzed = analyze_type_info(int)
    expected = AnalyzedTypeInfo(
        kind="Int64",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs=None,
        nullable=False,
    )
    assert analyzed == expected
|
426
|
+
|
427
|
+
|
428
|
+
def test_type_with_attributes() -> None:
    """A ``TypeAttr`` annotation on ``str`` should surface in ``attrs``."""
    attributed_str = Annotated[str, TypeAttr("key", "value")]
    analyzed = analyze_type_info(attributed_str)
    expected = AnalyzedTypeInfo(
        kind="Str",
        vector_info=None,
        elem_type=None,
        key_type=None,
        struct_type=None,
        np_number_type=None,
        attrs={"key": "value"},
        nullable=False,
    )
    assert analyzed == expected
|
441
|
+
|
442
|
+
|
443
|
+
def test_encode_enriched_type_none() -> None:
    """Encoding ``None`` should return ``None``."""
    assert encode_enriched_type(None) is None
|
447
|
+
|
448
|
+
|
449
|
+
def test_encode_enriched_type_struct() -> None:
    """A dataclass should encode to a Struct listing its two fields in order."""
    struct_schema = encode_enriched_type(SimpleDataclass)["type"]
    assert struct_schema["kind"] == "Struct"
    # Comparing the whole list pins the field count, order, names and kinds.
    name_and_kind = [
        (field["name"], field["type"]["kind"]) for field in struct_schema["fields"]
    ]
    assert name_and_kind == [("name", "Str"), ("value", "Int64")]
|
458
|
+
|
459
|
+
|
460
|
+
def test_encode_enriched_type_vector() -> None:
    """A float32 NDArray should encode to a Vector of Float32 with a null dimension."""
    vector_schema = encode_enriched_type(NDArray[np.float32])["type"]
    assert vector_schema["kind"] == "Vector"
    assert vector_schema["element_type"]["kind"] == "Float32"
    assert vector_schema["dimension"] is None
|
466
|
+
|
467
|
+
|
468
|
+
def test_encode_enriched_type_ltable() -> None:
    """A list of dataclasses should encode to an LTable whose row is a 2-field Struct."""
    table_schema = encode_enriched_type(List[SimpleDataclass])["type"]
    assert table_schema["kind"] == "LTable"
    row_schema = table_schema["row"]
    assert row_schema["kind"] == "Struct"
    assert len(row_schema["fields"]) == 2
|
474
|
+
|
475
|
+
|
476
|
+
def test_encode_enriched_type_with_attrs() -> None:
    """A ``TypeAttr`` annotation should appear under ``attrs`` in the encoding."""
    encoded = encode_enriched_type(Annotated[str, TypeAttr("key", "value")])
    assert encoded["type"]["kind"] == "Str"
    assert encoded["attrs"] == {"key": "value"}
|
481
|
+
|
482
|
+
|
483
|
+
def test_encode_enriched_type_nullable() -> None:
    """``str | None`` should encode as Str with ``nullable`` set to ``True``."""
    encoded = encode_enriched_type(str | None)
    assert encoded["type"]["kind"] == "Str"
    assert encoded["nullable"] is True
|
488
|
+
|
489
|
+
|
490
|
+
def test_invalid_struct_kind() -> None:
    """Annotating a dataclass with kind ``Vector`` should raise a ValueError."""
    # Build the alias outside the ``raises`` block so only the analysis call
    # can satisfy the expectation.
    mislabeled_struct = Annotated[SimpleDataclass, TypeKind("Vector")]
    expected_message = "Unexpected type kind for struct: Vector"
    with pytest.raises(ValueError, match=expected_message):
        analyze_type_info(mislabeled_struct)
|
494
|
+
|
495
|
+
|
496
|
+
def test_invalid_list_kind() -> None:
    """Annotating ``List[int]`` with kind ``Struct`` should raise a ValueError."""
    # Build the alias outside the ``raises`` block so only the analysis call
    # can satisfy the expectation.
    mislabeled_list = Annotated[List[int], TypeKind("Struct")]
    expected_message = "Unexpected type kind for list: Struct"
    with pytest.raises(ValueError, match=expected_message):
        analyze_type_info(mislabeled_list)
|
500
|
+
|
501
|
+
|
502
|
+
def test_unsupported_type() -> None:
    """The bare ``set`` type should be rejected with a ValueError."""
    unsupported = set
    expected_message = "type unsupported yet: <class 'set'>"
    with pytest.raises(ValueError, match=expected_message):
        analyze_type_info(unsupported)
|
cocoindex/typing.py
CHANGED
@@ -9,14 +9,16 @@ from typing import (
|
|
9
9
|
Annotated,
|
10
10
|
NamedTuple,
|
11
11
|
Any,
|
12
|
+
KeysView,
|
12
13
|
TypeVar,
|
13
14
|
TYPE_CHECKING,
|
14
15
|
overload,
|
15
|
-
Sequence,
|
16
16
|
Generic,
|
17
17
|
Literal,
|
18
18
|
Protocol,
|
19
19
|
)
|
20
|
+
import numpy as np
|
21
|
+
from numpy.typing import NDArray
|
20
22
|
|
21
23
|
|
22
24
|
class VectorInfo(NamedTuple):
|
@@ -47,10 +49,10 @@ OffsetDateTime = Annotated[datetime.datetime, TypeKind("OffsetDateTime")]
|
|
47
49
|
|
48
50
|
if TYPE_CHECKING:
|
49
51
|
T_co = TypeVar("T_co", covariant=True)
|
50
|
-
Dim_co = TypeVar("Dim_co", bound=int, covariant=True)
|
52
|
+
Dim_co = TypeVar("Dim_co", bound=int | None, covariant=True, default=None)
|
51
53
|
|
52
54
|
class Vector(Protocol, Generic[T_co, Dim_co]):
|
53
|
-
"""Vector[T, Dim] is a special typing alias for
|
55
|
+
"""Vector[T, Dim] is a special typing alias for an NDArray[T] with optional dimension info"""
|
54
56
|
|
55
57
|
def __getitem__(self, index: int) -> T_co: ...
|
56
58
|
def __len__(self) -> int: ...
|
@@ -58,25 +60,34 @@ if TYPE_CHECKING:
|
|
58
60
|
else:
|
59
61
|
|
60
62
|
class Vector:  # type: ignore[unreachable]
    """Runtime form of the ``Vector[T]`` / ``Vector[T, Literal[N]]`` typing alias.

    Subscripting this class does not create a ``Vector`` type. It returns an
    ``Annotated`` alias wrapping either ``NDArray[T]`` or ``list[T]``, tagged
    with a ``VectorInfo`` that carries the optional dimension.
    """

    def __class_getitem__(cls, params):
        """Build the ``Annotated`` alias for ``Vector[...]``.

        Args:
            params: Either the element type alone (``Vector[np.float32]``) or
                a 2-tuple of element type and dimension
                (``Vector[np.float32, Literal[3]]``).

        Returns:
            ``Annotated[NDArray[T], VectorInfo(dim=...)]`` when
            ``DtypeRegistry.get_by_dtype(T)`` returns an entry, otherwise
            ``Annotated[list[T], VectorInfo(dim=...)]``. ``dim`` is the first
            argument of the ``Literal`` when one is given, else ``None``.

        Raises:
            ValueError: From tuple unpacking when ``params`` is a tuple whose
                length is not exactly two.
            Exception: Whatever ``DtypeRegistry.get_by_dtype`` raises for the
                element type propagates unchanged.
        """
        if isinstance(params, tuple):
            # Element type and dimension provided, e.g. Vector[np.float32, Literal[3]]
            dtype, dim_literal = params
            # Only a Literal[...] contributes a dimension; any other second
            # argument is treated as "no dimension".
            dim_val = (
                typing.get_args(dim_literal)[0]
                if typing.get_origin(dim_literal) is Literal
                else None
            )
        else:
            # No dimension provided, e.g. Vector[np.float32]
            dtype, dim_val = params, None

        # Use NDArray for registered numeric dtypes, else a plain list.
        if DtypeRegistry.get_by_dtype(dtype) is not None:
            container = NDArray[dtype]
        else:
            container = list[dtype]
        return Annotated[container, VectorInfo(dim=dim_val)]
|
74
85
|
|
75
86
|
|
76
87
|
TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
|
77
88
|
KEY_FIELD_NAME: str = "_key"
|
78
89
|
|
79
|
-
ElementType = type | tuple[type, type]
|
90
|
+
ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
|
80
91
|
|
81
92
|
|
82
93
|
def is_namedtuple_type(t: type) -> bool:
|
@@ -89,6 +100,43 @@ def _is_struct_type(t: ElementType | None) -> bool:
|
|
89
100
|
)
|
90
101
|
|
91
102
|
|
103
|
+
class DtypeInfo:
    """Describes one NumPy dtype and the types it is paired with.

    Instances store the three constructor arguments unchanged, plus
    ``annotated_type``: ``python_type`` wrapped in ``Annotated`` with a
    ``TypeKind`` built from ``kind``.
    """

    def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
        """Record the dtype triple and derive its annotated alias.

        Args:
            numpy_dtype: The NumPy scalar type being described.
            kind: Kind name passed to ``TypeKind``.
            python_type: The Python type wrapped by ``annotated_type``.
        """
        annotated_alias = Annotated[python_type, TypeKind(kind)]
        self.numpy_dtype, self.kind, self.python_type = numpy_dtype, kind, python_type
        self.annotated_type = annotated_alias
|
111
|
+
|
112
|
+
|
113
|
+
class DtypeRegistry:
    """
    Registry for NumPy dtypes used in CocoIndex.
    Provides mappings from NumPy dtypes to CocoIndex's type representation.
    """

    # Keyed by each entry's own ``numpy_dtype`` so the key and the DtypeInfo
    # can never disagree.
    _mappings: dict[type, DtypeInfo] = {
        info.numpy_dtype: info
        for info in (
            DtypeInfo(np.float32, "Float32", float),
            DtypeInfo(np.float64, "Float64", float),
            DtypeInfo(np.int64, "Int64", int),
        )
    }

    @classmethod
    def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
        """Look up the DtypeInfo registered for a NumPy dtype.

        Args:
            dtype: Candidate dtype; any object may be passed.

        Returns:
            The registered ``DtypeInfo``, or ``None`` when ``dtype`` is not
            one of the registered NumPy dtypes.

        Raises:
            TypeError: If ``dtype`` is ``typing.Any``.
        """
        if dtype is Any:
            raise TypeError(
                "NDArray for Vector must use a concrete numpy dtype, got `Any`."
            )
        return cls._mappings.get(dtype)

    @classmethod
    def supported_dtypes(cls) -> KeysView[type]:
        """Return a keys view (not a list) of the registered NumPy dtypes."""
        return cls._mappings.keys()
|
138
|
+
|
139
|
+
|
92
140
|
@dataclasses.dataclass
|
93
141
|
class AnalyzedTypeInfo:
|
94
142
|
"""
|
@@ -101,6 +149,9 @@ class AnalyzedTypeInfo:
|
|
101
149
|
|
102
150
|
key_type: type | None # For element of KTable
|
103
151
|
struct_type: type | None # For Struct, a dataclass or namedtuple
|
152
|
+
np_number_type: (
|
153
|
+
type | None
|
154
|
+
) # NumPy dtype for the element type, if represented by numpy.ndarray or a NumPy scalar
|
104
155
|
|
105
156
|
attrs: dict[str, Any] | None
|
106
157
|
nullable: bool = False
|
@@ -155,6 +206,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
155
206
|
struct_type: type | None = None
|
156
207
|
elem_type: ElementType | None = None
|
157
208
|
key_type: type | None = None
|
209
|
+
np_number_type: type | None = None
|
158
210
|
if _is_struct_type(t):
|
159
211
|
struct_type = t
|
160
212
|
|
@@ -179,12 +231,35 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
179
231
|
vector_info = VectorInfo(dim=None)
|
180
232
|
elif not (kind == "Vector" or kind in TABLE_TYPES):
|
181
233
|
raise ValueError(f"Unexpected type kind for list: {kind}")
|
234
|
+
elif base_type is np.ndarray:
|
235
|
+
kind = "Vector"
|
236
|
+
args = typing.get_args(t)
|
237
|
+
_, dtype_spec = args
|
238
|
+
|
239
|
+
dtype_args = typing.get_args(dtype_spec)
|
240
|
+
if not dtype_args:
|
241
|
+
raise ValueError("Invalid dtype specification for NDArray")
|
242
|
+
|
243
|
+
np_number_type = dtype_args[0]
|
244
|
+
dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
|
245
|
+
if dtype_info is None:
|
246
|
+
raise ValueError(
|
247
|
+
f"Unsupported numpy dtype for NDArray: {np_number_type}. "
|
248
|
+
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
|
249
|
+
)
|
250
|
+
elem_type = dtype_info.annotated_type
|
251
|
+
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
|
252
|
+
|
182
253
|
elif base_type is collections.abc.Mapping or base_type is dict:
|
183
254
|
args = typing.get_args(t)
|
184
255
|
elem_type = (args[0], args[1])
|
185
256
|
kind = "KTable"
|
186
257
|
elif kind is None:
|
187
|
-
|
258
|
+
dtype_info = DtypeRegistry.get_by_dtype(t)
|
259
|
+
if dtype_info is not None:
|
260
|
+
kind = dtype_info.kind
|
261
|
+
np_number_type = dtype_info.numpy_dtype
|
262
|
+
elif t is bytes:
|
188
263
|
kind = "Bytes"
|
189
264
|
elif t is str:
|
190
265
|
kind = "Str"
|
@@ -213,6 +288,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
213
288
|
elem_type=elem_type,
|
214
289
|
key_type=key_type,
|
215
290
|
struct_type=struct_type,
|
291
|
+
np_number_type=np_number_type,
|
216
292
|
attrs=attrs,
|
217
293
|
nullable=nullable,
|
218
294
|
)
|
@@ -265,9 +341,8 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
|
|
265
341
|
raise ValueError("Vector type must have a vector info")
|
266
342
|
if type_info.elem_type is None:
|
267
343
|
raise ValueError("Vector type must have an element type")
|
268
|
-
|
269
|
-
|
270
|
-
)
|
344
|
+
elem_type_info = analyze_type_info(type_info.elem_type)
|
345
|
+
encoded_type["element_type"] = _encode_type(elem_type_info)
|
271
346
|
encoded_type["dimension"] = type_info.vector_info.dim
|
272
347
|
|
273
348
|
elif type_info.kind in TABLE_TYPES:
|