cocoindex 0.1.52__cp311-cp311-win_amd64.whl → 0.1.54__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/_engine.cp311-win_amd64.pyd +0 -0
- cocoindex/cli.py +6 -6
- cocoindex/convert.py +93 -46
- cocoindex/flow.py +3 -2
- cocoindex/functions.py +10 -0
- cocoindex/llm.py +3 -0
- cocoindex/tests/__init__.py +0 -1
- cocoindex/tests/test_convert.py +289 -58
- cocoindex/tests/test_typing.py +115 -77
- cocoindex/typing.py +76 -64
- {cocoindex-0.1.52.dist-info → cocoindex-0.1.54.dist-info}/METADATA +11 -10
- cocoindex-0.1.54.dist-info/RECORD +28 -0
- cocoindex-0.1.52.dist-info/RECORD +0 -28
- {cocoindex-0.1.52.dist-info → cocoindex-0.1.54.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.52.dist-info → cocoindex-0.1.54.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.52.dist-info → cocoindex-0.1.54.dist-info}/licenses/LICENSE +0 -0
cocoindex/tests/test_typing.py
CHANGED
@@ -1,30 +1,21 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import datetime
|
3
3
|
import uuid
|
4
|
-
from
|
5
|
-
|
6
|
-
|
7
|
-
Dict,
|
8
|
-
Literal,
|
9
|
-
Any,
|
10
|
-
get_args,
|
11
|
-
NamedTuple,
|
12
|
-
)
|
13
|
-
from collections.abc import Sequence, Mapping
|
14
|
-
import pytest
|
4
|
+
from collections.abc import Mapping, Sequence
|
5
|
+
from typing import Annotated, Any, Literal, NamedTuple, get_args, get_origin
|
6
|
+
|
15
7
|
import numpy as np
|
8
|
+
import pytest
|
16
9
|
from numpy.typing import NDArray
|
17
10
|
|
18
11
|
from cocoindex.typing import (
|
19
|
-
|
12
|
+
AnalyzedTypeInfo,
|
13
|
+
TypeAttr,
|
14
|
+
TypeKind,
|
20
15
|
Vector,
|
21
16
|
VectorInfo,
|
22
|
-
|
23
|
-
TypeAttr,
|
24
|
-
Float32,
|
25
|
-
Float64,
|
17
|
+
analyze_type_info,
|
26
18
|
encode_enriched_type,
|
27
|
-
AnalyzedTypeInfo,
|
28
19
|
)
|
29
20
|
|
30
21
|
|
@@ -42,61 +33,57 @@ class SimpleNamedTuple(NamedTuple):
|
|
42
33
|
def test_ndarray_float32_no_dim() -> None:
|
43
34
|
typ = NDArray[np.float32]
|
44
35
|
result = analyze_type_info(typ)
|
45
|
-
assert result ==
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
)
|
36
|
+
assert result.kind == "Vector"
|
37
|
+
assert result.vector_info == VectorInfo(dim=None)
|
38
|
+
assert result.elem_type == np.float32
|
39
|
+
assert result.key_type is None
|
40
|
+
assert result.struct_type is None
|
41
|
+
assert result.nullable is False
|
42
|
+
assert result.np_number_type is not None
|
43
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
44
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
|
55
45
|
|
56
46
|
|
57
47
|
def test_vector_float32_no_dim() -> None:
|
58
48
|
typ = Vector[np.float32]
|
59
49
|
result = analyze_type_info(typ)
|
60
|
-
assert result ==
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
)
|
50
|
+
assert result.kind == "Vector"
|
51
|
+
assert result.vector_info == VectorInfo(dim=None)
|
52
|
+
assert result.elem_type == np.float32
|
53
|
+
assert result.key_type is None
|
54
|
+
assert result.struct_type is None
|
55
|
+
assert result.nullable is False
|
56
|
+
assert result.np_number_type is not None
|
57
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
58
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
|
70
59
|
|
71
60
|
|
72
61
|
def test_ndarray_float64_with_dim() -> None:
|
73
62
|
typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
|
74
63
|
result = analyze_type_info(typ)
|
75
|
-
assert result ==
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
)
|
64
|
+
assert result.kind == "Vector"
|
65
|
+
assert result.vector_info == VectorInfo(dim=128)
|
66
|
+
assert result.elem_type == np.float64
|
67
|
+
assert result.key_type is None
|
68
|
+
assert result.struct_type is None
|
69
|
+
assert result.nullable is False
|
70
|
+
assert result.np_number_type is not None
|
71
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
72
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.float64]
|
85
73
|
|
86
74
|
|
87
75
|
def test_vector_float32_with_dim() -> None:
|
88
76
|
typ = Vector[np.float32, Literal[384]]
|
89
77
|
result = analyze_type_info(typ)
|
90
|
-
assert result ==
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
)
|
78
|
+
assert result.kind == "Vector"
|
79
|
+
assert result.vector_info == VectorInfo(dim=384)
|
80
|
+
assert result.elem_type == np.float32
|
81
|
+
assert result.key_type is None
|
82
|
+
assert result.struct_type is None
|
83
|
+
assert result.nullable is False
|
84
|
+
assert result.np_number_type is not None
|
85
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
86
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
|
100
87
|
|
101
88
|
|
102
89
|
def test_ndarray_int64_no_dim() -> None:
|
@@ -104,30 +91,49 @@ def test_ndarray_int64_no_dim() -> None:
|
|
104
91
|
result = analyze_type_info(typ)
|
105
92
|
assert result.kind == "Vector"
|
106
93
|
assert result.vector_info == VectorInfo(dim=None)
|
107
|
-
assert
|
108
|
-
assert
|
94
|
+
assert result.elem_type == np.int64
|
95
|
+
assert result.nullable is False
|
96
|
+
assert result.np_number_type is not None
|
97
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
98
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.int64]
|
109
99
|
|
110
100
|
|
111
101
|
def test_nullable_ndarray() -> None:
|
112
102
|
typ = NDArray[np.float32] | None
|
113
103
|
result = analyze_type_info(typ)
|
114
|
-
assert result ==
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
104
|
+
assert result.kind == "Vector"
|
105
|
+
assert result.vector_info == VectorInfo(dim=None)
|
106
|
+
assert result.elem_type == np.float32
|
107
|
+
assert result.key_type is None
|
108
|
+
assert result.struct_type is None
|
109
|
+
assert result.nullable is True
|
110
|
+
assert result.np_number_type is not None
|
111
|
+
assert get_origin(result.np_number_type) == np.ndarray
|
112
|
+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
|
113
|
+
|
114
|
+
|
115
|
+
def test_scalar_numpy_types() -> None:
|
116
|
+
for np_type, expected_kind in [
|
117
|
+
(np.int64, "Int64"),
|
118
|
+
(np.float32, "Float32"),
|
119
|
+
(np.float64, "Float64"),
|
120
|
+
]:
|
121
|
+
type_info = analyze_type_info(np_type)
|
122
|
+
assert type_info.kind == expected_kind, (
|
123
|
+
f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
|
124
|
+
)
|
125
|
+
assert type_info.np_number_type == np_type, (
|
126
|
+
f"Expected {np_type}, got {type_info.np_number_type}"
|
127
|
+
)
|
128
|
+
assert type_info.elem_type is None
|
129
|
+
assert type_info.vector_info is None
|
124
130
|
|
125
131
|
|
126
132
|
def test_vector_str() -> None:
|
127
133
|
typ = Vector[str]
|
128
134
|
result = analyze_type_info(typ)
|
129
135
|
assert result.kind == "Vector"
|
130
|
-
assert result.elem_type
|
136
|
+
assert result.elem_type is str
|
131
137
|
assert result.vector_info == VectorInfo(dim=None)
|
132
138
|
|
133
139
|
|
@@ -143,7 +149,7 @@ def test_non_numpy_vector() -> None:
|
|
143
149
|
typ = Vector[float, Literal[3]]
|
144
150
|
result = analyze_type_info(typ)
|
145
151
|
assert result.kind == "Vector"
|
146
|
-
assert result.elem_type
|
152
|
+
assert result.elem_type is float
|
147
153
|
assert result.vector_info == VectorInfo(dim=3)
|
148
154
|
|
149
155
|
|
@@ -156,10 +162,11 @@ def test_ndarray_any_dtype() -> None:
|
|
156
162
|
|
157
163
|
|
158
164
|
def test_list_of_primitives() -> None:
|
159
|
-
typ =
|
165
|
+
typ = list[str]
|
160
166
|
result = analyze_type_info(typ)
|
161
167
|
assert result == AnalyzedTypeInfo(
|
162
168
|
kind="Vector",
|
169
|
+
core_type=list[str],
|
163
170
|
vector_info=VectorInfo(dim=None),
|
164
171
|
elem_type=str,
|
165
172
|
key_type=None,
|
@@ -171,10 +178,11 @@ def test_list_of_primitives() -> None:
|
|
171
178
|
|
172
179
|
|
173
180
|
def test_list_of_structs() -> None:
|
174
|
-
typ =
|
181
|
+
typ = list[SimpleDataclass]
|
175
182
|
result = analyze_type_info(typ)
|
176
183
|
assert result == AnalyzedTypeInfo(
|
177
184
|
kind="LTable",
|
185
|
+
core_type=list[SimpleDataclass],
|
178
186
|
vector_info=None,
|
179
187
|
elem_type=SimpleDataclass,
|
180
188
|
key_type=None,
|
@@ -190,6 +198,7 @@ def test_sequence_of_int() -> None:
|
|
190
198
|
result = analyze_type_info(typ)
|
191
199
|
assert result == AnalyzedTypeInfo(
|
192
200
|
kind="Vector",
|
201
|
+
core_type=Sequence[int],
|
193
202
|
vector_info=VectorInfo(dim=None),
|
194
203
|
elem_type=int,
|
195
204
|
key_type=None,
|
@@ -201,10 +210,11 @@ def test_sequence_of_int() -> None:
|
|
201
210
|
|
202
211
|
|
203
212
|
def test_list_with_vector_info() -> None:
|
204
|
-
typ = Annotated[
|
213
|
+
typ = Annotated[list[int], VectorInfo(dim=5)]
|
205
214
|
result = analyze_type_info(typ)
|
206
215
|
assert result == AnalyzedTypeInfo(
|
207
216
|
kind="Vector",
|
217
|
+
core_type=list[int],
|
208
218
|
vector_info=VectorInfo(dim=5),
|
209
219
|
elem_type=int,
|
210
220
|
key_type=None,
|
@@ -216,10 +226,11 @@ def test_list_with_vector_info() -> None:
|
|
216
226
|
|
217
227
|
|
218
228
|
def test_dict_str_int() -> None:
|
219
|
-
typ =
|
229
|
+
typ = dict[str, int]
|
220
230
|
result = analyze_type_info(typ)
|
221
231
|
assert result == AnalyzedTypeInfo(
|
222
232
|
kind="KTable",
|
233
|
+
core_type=dict[str, int],
|
223
234
|
vector_info=None,
|
224
235
|
elem_type=(str, int),
|
225
236
|
key_type=None,
|
@@ -235,6 +246,7 @@ def test_mapping_str_dataclass() -> None:
|
|
235
246
|
result = analyze_type_info(typ)
|
236
247
|
assert result == AnalyzedTypeInfo(
|
237
248
|
kind="KTable",
|
249
|
+
core_type=Mapping[str, SimpleDataclass],
|
238
250
|
vector_info=None,
|
239
251
|
elem_type=(str, SimpleDataclass),
|
240
252
|
key_type=None,
|
@@ -250,6 +262,7 @@ def test_dataclass() -> None:
|
|
250
262
|
result = analyze_type_info(typ)
|
251
263
|
assert result == AnalyzedTypeInfo(
|
252
264
|
kind="Struct",
|
265
|
+
core_type=SimpleDataclass,
|
253
266
|
vector_info=None,
|
254
267
|
elem_type=None,
|
255
268
|
key_type=None,
|
@@ -265,6 +278,7 @@ def test_named_tuple() -> None:
|
|
265
278
|
result = analyze_type_info(typ)
|
266
279
|
assert result == AnalyzedTypeInfo(
|
267
280
|
kind="Struct",
|
281
|
+
core_type=SimpleNamedTuple,
|
268
282
|
vector_info=None,
|
269
283
|
elem_type=None,
|
270
284
|
key_type=None,
|
@@ -280,6 +294,7 @@ def test_tuple_key_value() -> None:
|
|
280
294
|
result = analyze_type_info(typ)
|
281
295
|
assert result == AnalyzedTypeInfo(
|
282
296
|
kind="Int64",
|
297
|
+
core_type=int,
|
283
298
|
vector_info=None,
|
284
299
|
elem_type=None,
|
285
300
|
key_type=str,
|
@@ -295,6 +310,7 @@ def test_str() -> None:
|
|
295
310
|
result = analyze_type_info(typ)
|
296
311
|
assert result == AnalyzedTypeInfo(
|
297
312
|
kind="Str",
|
313
|
+
core_type=str,
|
298
314
|
vector_info=None,
|
299
315
|
elem_type=None,
|
300
316
|
key_type=None,
|
@@ -310,6 +326,7 @@ def test_bool() -> None:
|
|
310
326
|
result = analyze_type_info(typ)
|
311
327
|
assert result == AnalyzedTypeInfo(
|
312
328
|
kind="Bool",
|
329
|
+
core_type=bool,
|
313
330
|
vector_info=None,
|
314
331
|
elem_type=None,
|
315
332
|
key_type=None,
|
@@ -325,6 +342,7 @@ def test_bytes() -> None:
|
|
325
342
|
result = analyze_type_info(typ)
|
326
343
|
assert result == AnalyzedTypeInfo(
|
327
344
|
kind="Bytes",
|
345
|
+
core_type=bytes,
|
328
346
|
vector_info=None,
|
329
347
|
elem_type=None,
|
330
348
|
key_type=None,
|
@@ -340,6 +358,7 @@ def test_uuid() -> None:
|
|
340
358
|
result = analyze_type_info(typ)
|
341
359
|
assert result == AnalyzedTypeInfo(
|
342
360
|
kind="Uuid",
|
361
|
+
core_type=uuid.UUID,
|
343
362
|
vector_info=None,
|
344
363
|
elem_type=None,
|
345
364
|
key_type=None,
|
@@ -355,6 +374,7 @@ def test_date() -> None:
|
|
355
374
|
result = analyze_type_info(typ)
|
356
375
|
assert result == AnalyzedTypeInfo(
|
357
376
|
kind="Date",
|
377
|
+
core_type=datetime.date,
|
358
378
|
vector_info=None,
|
359
379
|
elem_type=None,
|
360
380
|
key_type=None,
|
@@ -370,6 +390,7 @@ def test_time() -> None:
|
|
370
390
|
result = analyze_type_info(typ)
|
371
391
|
assert result == AnalyzedTypeInfo(
|
372
392
|
kind="Time",
|
393
|
+
core_type=datetime.time,
|
373
394
|
vector_info=None,
|
374
395
|
elem_type=None,
|
375
396
|
key_type=None,
|
@@ -385,6 +406,7 @@ def test_timedelta() -> None:
|
|
385
406
|
result = analyze_type_info(typ)
|
386
407
|
assert result == AnalyzedTypeInfo(
|
387
408
|
kind="TimeDelta",
|
409
|
+
core_type=datetime.timedelta,
|
388
410
|
vector_info=None,
|
389
411
|
elem_type=None,
|
390
412
|
key_type=None,
|
@@ -400,6 +422,7 @@ def test_float() -> None:
|
|
400
422
|
result = analyze_type_info(typ)
|
401
423
|
assert result == AnalyzedTypeInfo(
|
402
424
|
kind="Float64",
|
425
|
+
core_type=float,
|
403
426
|
vector_info=None,
|
404
427
|
elem_type=None,
|
405
428
|
key_type=None,
|
@@ -415,6 +438,7 @@ def test_int() -> None:
|
|
415
438
|
result = analyze_type_info(typ)
|
416
439
|
assert result == AnalyzedTypeInfo(
|
417
440
|
kind="Int64",
|
441
|
+
core_type=int,
|
418
442
|
vector_info=None,
|
419
443
|
elem_type=None,
|
420
444
|
key_type=None,
|
@@ -430,6 +454,7 @@ def test_type_with_attributes() -> None:
|
|
430
454
|
result = analyze_type_info(typ)
|
431
455
|
assert result == AnalyzedTypeInfo(
|
432
456
|
kind="Str",
|
457
|
+
core_type=str,
|
433
458
|
vector_info=None,
|
434
459
|
elem_type=None,
|
435
460
|
key_type=None,
|
@@ -466,7 +491,7 @@ def test_encode_enriched_type_vector() -> None:
|
|
466
491
|
|
467
492
|
|
468
493
|
def test_encode_enriched_type_ltable() -> None:
|
469
|
-
typ =
|
494
|
+
typ = list[SimpleDataclass]
|
470
495
|
result = encode_enriched_type(typ)
|
471
496
|
assert result["type"]["kind"] == "LTable"
|
472
497
|
assert result["type"]["row"]["kind"] == "Struct"
|
@@ -487,6 +512,19 @@ def test_encode_enriched_type_nullable() -> None:
|
|
487
512
|
assert result["nullable"] is True
|
488
513
|
|
489
514
|
|
515
|
+
def test_encode_scalar_numpy_types_schema() -> None:
|
516
|
+
for np_type, expected_kind in [
|
517
|
+
(np.int64, "Int64"),
|
518
|
+
(np.float32, "Float32"),
|
519
|
+
(np.float64, "Float64"),
|
520
|
+
]:
|
521
|
+
schema = encode_enriched_type(np_type)
|
522
|
+
assert schema["type"]["kind"] == expected_kind, (
|
523
|
+
f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
|
524
|
+
)
|
525
|
+
assert not schema.get("nullable", False)
|
526
|
+
|
527
|
+
|
490
528
|
def test_invalid_struct_kind() -> None:
|
491
529
|
typ = Annotated[SimpleDataclass, TypeKind("Vector")]
|
492
530
|
with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
|
@@ -494,7 +532,7 @@ def test_invalid_struct_kind() -> None:
|
|
494
532
|
|
495
533
|
|
496
534
|
def test_invalid_list_kind() -> None:
|
497
|
-
typ = Annotated[
|
535
|
+
typ = Annotated[list[int], TypeKind("Struct")]
|
498
536
|
with pytest.raises(ValueError, match="Unexpected type kind for list: Struct"):
|
499
537
|
analyze_type_info(typ)
|
500
538
|
|
cocoindex/typing.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
|
-
import typing
|
2
1
|
import collections
|
3
2
|
import dataclasses
|
4
3
|
import datetime
|
5
|
-
import types
|
6
4
|
import inspect
|
5
|
+
import types
|
6
|
+
import typing
|
7
7
|
import uuid
|
8
8
|
from typing import (
|
9
|
+
TYPE_CHECKING,
|
9
10
|
Annotated,
|
10
|
-
NamedTuple,
|
11
11
|
Any,
|
12
|
-
KeysView,
|
13
|
-
TypeVar,
|
14
|
-
TYPE_CHECKING,
|
15
|
-
overload,
|
16
12
|
Generic,
|
17
13
|
Literal,
|
14
|
+
NamedTuple,
|
18
15
|
Protocol,
|
16
|
+
TypeVar,
|
17
|
+
overload,
|
19
18
|
)
|
19
|
+
|
20
20
|
import numpy as np
|
21
21
|
from numpy.typing import NDArray
|
22
22
|
|
@@ -67,7 +67,7 @@ else:
|
|
67
67
|
# No dimension provided, e.g., Vector[np.float32]
|
68
68
|
dtype = params
|
69
69
|
# Use NDArray for supported numeric dtypes, else list
|
70
|
-
if
|
70
|
+
if dtype in DtypeRegistry._DTYPE_TO_KIND:
|
71
71
|
return Annotated[NDArray[dtype], VectorInfo(dim=None)]
|
72
72
|
return Annotated[list[dtype], VectorInfo(dim=None)]
|
73
73
|
else:
|
@@ -79,7 +79,7 @@ else:
|
|
79
79
|
if typing.get_origin(dim_literal) is Literal
|
80
80
|
else None
|
81
81
|
)
|
82
|
-
if
|
82
|
+
if dtype in DtypeRegistry._DTYPE_TO_KIND:
|
83
83
|
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
|
84
84
|
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
|
85
85
|
|
@@ -90,6 +90,19 @@ KEY_FIELD_NAME: str = "_key"
|
|
90
90
|
ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
|
91
91
|
|
92
92
|
|
93
|
+
def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
|
94
|
+
args = typing.get_args(ndarray_type)
|
95
|
+
_, dtype_spec = args
|
96
|
+
dtype_args = typing.get_args(dtype_spec)
|
97
|
+
if not dtype_args:
|
98
|
+
raise ValueError(f"Invalid dtype specification: {dtype_spec}")
|
99
|
+
return dtype_args[0]
|
100
|
+
|
101
|
+
|
102
|
+
def is_numpy_number_type(t: type) -> bool:
|
103
|
+
return isinstance(t, type) and issubclass(t, np.number)
|
104
|
+
|
105
|
+
|
93
106
|
def is_namedtuple_type(t: type) -> bool:
|
94
107
|
return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")
|
95
108
|
|
@@ -100,41 +113,34 @@ def _is_struct_type(t: ElementType | None) -> bool:
|
|
100
113
|
)
|
101
114
|
|
102
115
|
|
103
|
-
class DtypeInfo:
|
104
|
-
"""Metadata for a NumPy dtype."""
|
105
|
-
|
106
|
-
def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
|
107
|
-
self.numpy_dtype = numpy_dtype
|
108
|
-
self.kind = kind
|
109
|
-
self.python_type = python_type
|
110
|
-
self.annotated_type = Annotated[python_type, TypeKind(kind)]
|
111
|
-
|
112
|
-
|
113
116
|
class DtypeRegistry:
|
114
117
|
"""
|
115
118
|
Registry for NumPy dtypes used in CocoIndex.
|
116
|
-
|
119
|
+
Maps NumPy dtypes to their CocoIndex type kind.
|
117
120
|
"""
|
118
121
|
|
119
|
-
|
120
|
-
np.float32:
|
121
|
-
np.float64:
|
122
|
-
np.int64:
|
122
|
+
_DTYPE_TO_KIND: dict[ElementType, str] = {
|
123
|
+
np.float32: "Float32",
|
124
|
+
np.float64: "Float64",
|
125
|
+
np.int64: "Int64",
|
123
126
|
}
|
124
127
|
|
125
128
|
@classmethod
|
126
|
-
def
|
127
|
-
"""
|
129
|
+
def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str:
|
130
|
+
"""
|
131
|
+
Validate that the given dtype is supported, and get its CocoIndex kind by dtype.
|
132
|
+
"""
|
128
133
|
if dtype is Any:
|
129
134
|
raise TypeError(
|
130
135
|
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
|
131
136
|
)
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
137
|
+
kind = cls._DTYPE_TO_KIND.get(dtype)
|
138
|
+
if kind is None:
|
139
|
+
raise ValueError(
|
140
|
+
f"Unsupported NumPy dtype in NDArray: {dtype}. "
|
141
|
+
f"Supported dtypes: {cls._DTYPE_TO_KIND.keys()}"
|
142
|
+
)
|
143
|
+
return kind
|
138
144
|
|
139
145
|
|
140
146
|
@dataclasses.dataclass
|
@@ -144,6 +150,7 @@ class AnalyzedTypeInfo:
|
|
144
150
|
"""
|
145
151
|
|
146
152
|
kind: str
|
153
|
+
core_type: Any
|
147
154
|
vector_info: VectorInfo | None # For Vector
|
148
155
|
elem_type: ElementType | None # For Vector and Table
|
149
156
|
|
@@ -155,6 +162,7 @@ class AnalyzedTypeInfo:
|
|
155
162
|
|
156
163
|
attrs: dict[str, Any] | None
|
157
164
|
nullable: bool = False
|
165
|
+
union_variant_types: typing.List[ElementType] | None = None # For Union
|
158
166
|
|
159
167
|
|
160
168
|
def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
@@ -175,18 +183,6 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
175
183
|
if base_type is Annotated:
|
176
184
|
annotations = t.__metadata__
|
177
185
|
t = t.__origin__
|
178
|
-
elif base_type is types.UnionType:
|
179
|
-
possible_types = typing.get_args(t)
|
180
|
-
non_none_types = [
|
181
|
-
arg for arg in possible_types if arg not in (None, types.NoneType)
|
182
|
-
]
|
183
|
-
if len(non_none_types) != 1:
|
184
|
-
raise ValueError(
|
185
|
-
f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}"
|
186
|
-
)
|
187
|
-
t = non_none_types[0]
|
188
|
-
if len(possible_types) > 1:
|
189
|
-
nullable = True
|
190
186
|
else:
|
191
187
|
break
|
192
188
|
|
@@ -205,6 +201,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
205
201
|
|
206
202
|
struct_type: type | None = None
|
207
203
|
elem_type: ElementType | None = None
|
204
|
+
union_variant_types: typing.List[ElementType] | None = None
|
208
205
|
key_type: type | None = None
|
209
206
|
np_number_type: type | None = None
|
210
207
|
if _is_struct_type(t):
|
@@ -214,6 +211,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
214
211
|
kind = "Struct"
|
215
212
|
elif kind != "Struct":
|
216
213
|
raise ValueError(f"Unexpected type kind for struct: {kind}")
|
214
|
+
elif is_numpy_number_type(t):
|
215
|
+
np_number_type = t
|
216
|
+
kind = DtypeRegistry.validate_dtype_and_get_kind(t)
|
217
217
|
elif base_type is collections.abc.Sequence or base_type is list:
|
218
218
|
args = typing.get_args(t)
|
219
219
|
elem_type = args[0]
|
@@ -233,33 +233,35 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
233
233
|
raise ValueError(f"Unexpected type kind for list: {kind}")
|
234
234
|
elif base_type is np.ndarray:
|
235
235
|
kind = "Vector"
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
dtype_args = typing.get_args(dtype_spec)
|
240
|
-
if not dtype_args:
|
241
|
-
raise ValueError("Invalid dtype specification for NDArray")
|
242
|
-
|
243
|
-
np_number_type = dtype_args[0]
|
244
|
-
dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
|
245
|
-
if dtype_info is None:
|
246
|
-
raise ValueError(
|
247
|
-
f"Unsupported numpy dtype for NDArray: {np_number_type}. "
|
248
|
-
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
|
249
|
-
)
|
250
|
-
elem_type = dtype_info.annotated_type
|
236
|
+
np_number_type = t
|
237
|
+
elem_type = extract_ndarray_scalar_dtype(np_number_type)
|
238
|
+
_ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
|
251
239
|
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
|
252
240
|
|
253
241
|
elif base_type is collections.abc.Mapping or base_type is dict:
|
254
242
|
args = typing.get_args(t)
|
255
243
|
elem_type = (args[0], args[1])
|
256
244
|
kind = "KTable"
|
245
|
+
elif base_type is types.UnionType:
|
246
|
+
possible_types = typing.get_args(t)
|
247
|
+
non_none_types = [
|
248
|
+
arg for arg in possible_types if arg not in (None, types.NoneType)
|
249
|
+
]
|
250
|
+
|
251
|
+
if len(non_none_types) == 0:
|
252
|
+
return analyze_type_info(None)
|
253
|
+
|
254
|
+
nullable = len(non_none_types) < len(possible_types)
|
255
|
+
|
256
|
+
if len(non_none_types) == 1:
|
257
|
+
result = analyze_type_info(non_none_types[0])
|
258
|
+
result.nullable = nullable
|
259
|
+
return result
|
260
|
+
|
261
|
+
kind = "Union"
|
262
|
+
union_variant_types = non_none_types
|
257
263
|
elif kind is None:
|
258
|
-
|
259
|
-
if dtype_info is not None:
|
260
|
-
kind = dtype_info.kind
|
261
|
-
np_number_type = dtype_info.numpy_dtype
|
262
|
-
elif t is bytes:
|
264
|
+
if t is bytes:
|
263
265
|
kind = "Bytes"
|
264
266
|
elif t is str:
|
265
267
|
kind = "Str"
|
@@ -284,8 +286,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
284
286
|
|
285
287
|
return AnalyzedTypeInfo(
|
286
288
|
kind=kind,
|
289
|
+
core_type=t,
|
287
290
|
vector_info=vector_info,
|
288
291
|
elem_type=elem_type,
|
292
|
+
union_variant_types=union_variant_types,
|
289
293
|
key_type=key_type,
|
290
294
|
struct_type=struct_type,
|
291
295
|
np_number_type=np_number_type,
|
@@ -345,6 +349,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
|
|
345
349
|
encoded_type["element_type"] = _encode_type(elem_type_info)
|
346
350
|
encoded_type["dimension"] = type_info.vector_info.dim
|
347
351
|
|
352
|
+
elif type_info.kind == "Union":
|
353
|
+
if type_info.union_variant_types is None:
|
354
|
+
raise ValueError("Union type must have a variant type list")
|
355
|
+
encoded_type["types"] = [
|
356
|
+
_encode_type(analyze_type_info(typ))
|
357
|
+
for typ in type_info.union_variant_types
|
358
|
+
]
|
359
|
+
|
348
360
|
elif type_info.kind in TABLE_TYPES:
|
349
361
|
if type_info.elem_type is None:
|
350
362
|
raise ValueError(f"{type_info.kind} type must have an element type")
|