cocoindex 0.1.52__cp311-cp311-win_amd64.whl → 0.1.54__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,30 +1,21 @@
1
1
  import dataclasses
2
2
  import datetime
3
3
  import uuid
4
- from typing import (
5
- Annotated,
6
- List,
7
- Dict,
8
- Literal,
9
- Any,
10
- get_args,
11
- NamedTuple,
12
- )
13
- from collections.abc import Sequence, Mapping
14
- import pytest
4
+ from collections.abc import Mapping, Sequence
5
+ from typing import Annotated, Any, Literal, NamedTuple, get_args, get_origin
6
+
15
7
  import numpy as np
8
+ import pytest
16
9
  from numpy.typing import NDArray
17
10
 
18
11
  from cocoindex.typing import (
19
- analyze_type_info,
12
+ AnalyzedTypeInfo,
13
+ TypeAttr,
14
+ TypeKind,
20
15
  Vector,
21
16
  VectorInfo,
22
- TypeKind,
23
- TypeAttr,
24
- Float32,
25
- Float64,
17
+ analyze_type_info,
26
18
  encode_enriched_type,
27
- AnalyzedTypeInfo,
28
19
  )
29
20
 
30
21
 
@@ -42,61 +33,57 @@ class SimpleNamedTuple(NamedTuple):
42
33
  def test_ndarray_float32_no_dim() -> None:
43
34
  typ = NDArray[np.float32]
44
35
  result = analyze_type_info(typ)
45
- assert result == AnalyzedTypeInfo(
46
- kind="Vector",
47
- vector_info=VectorInfo(dim=None),
48
- elem_type=Float32,
49
- key_type=None,
50
- struct_type=None,
51
- np_number_type=np.float32,
52
- attrs=None,
53
- nullable=False,
54
- )
36
+ assert result.kind == "Vector"
37
+ assert result.vector_info == VectorInfo(dim=None)
38
+ assert result.elem_type == np.float32
39
+ assert result.key_type is None
40
+ assert result.struct_type is None
41
+ assert result.nullable is False
42
+ assert result.np_number_type is not None
43
+ assert get_origin(result.np_number_type) == np.ndarray
44
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
55
45
 
56
46
 
57
47
  def test_vector_float32_no_dim() -> None:
58
48
  typ = Vector[np.float32]
59
49
  result = analyze_type_info(typ)
60
- assert result == AnalyzedTypeInfo(
61
- kind="Vector",
62
- vector_info=VectorInfo(dim=None),
63
- elem_type=Float32,
64
- key_type=None,
65
- struct_type=None,
66
- np_number_type=np.float32,
67
- attrs=None,
68
- nullable=False,
69
- )
50
+ assert result.kind == "Vector"
51
+ assert result.vector_info == VectorInfo(dim=None)
52
+ assert result.elem_type == np.float32
53
+ assert result.key_type is None
54
+ assert result.struct_type is None
55
+ assert result.nullable is False
56
+ assert result.np_number_type is not None
57
+ assert get_origin(result.np_number_type) == np.ndarray
58
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
70
59
 
71
60
 
72
61
  def test_ndarray_float64_with_dim() -> None:
73
62
  typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
74
63
  result = analyze_type_info(typ)
75
- assert result == AnalyzedTypeInfo(
76
- kind="Vector",
77
- vector_info=VectorInfo(dim=128),
78
- elem_type=Float64,
79
- key_type=None,
80
- struct_type=None,
81
- np_number_type=np.float64,
82
- attrs=None,
83
- nullable=False,
84
- )
64
+ assert result.kind == "Vector"
65
+ assert result.vector_info == VectorInfo(dim=128)
66
+ assert result.elem_type == np.float64
67
+ assert result.key_type is None
68
+ assert result.struct_type is None
69
+ assert result.nullable is False
70
+ assert result.np_number_type is not None
71
+ assert get_origin(result.np_number_type) == np.ndarray
72
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float64]
85
73
 
86
74
 
87
75
  def test_vector_float32_with_dim() -> None:
88
76
  typ = Vector[np.float32, Literal[384]]
89
77
  result = analyze_type_info(typ)
90
- assert result == AnalyzedTypeInfo(
91
- kind="Vector",
92
- vector_info=VectorInfo(dim=384),
93
- elem_type=Float32,
94
- key_type=None,
95
- struct_type=None,
96
- np_number_type=np.float32,
97
- attrs=None,
98
- nullable=False,
99
- )
78
+ assert result.kind == "Vector"
79
+ assert result.vector_info == VectorInfo(dim=384)
80
+ assert result.elem_type == np.float32
81
+ assert result.key_type is None
82
+ assert result.struct_type is None
83
+ assert result.nullable is False
84
+ assert result.np_number_type is not None
85
+ assert get_origin(result.np_number_type) == np.ndarray
86
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
100
87
 
101
88
 
102
89
  def test_ndarray_int64_no_dim() -> None:
@@ -104,30 +91,49 @@ def test_ndarray_int64_no_dim() -> None:
104
91
  result = analyze_type_info(typ)
105
92
  assert result.kind == "Vector"
106
93
  assert result.vector_info == VectorInfo(dim=None)
107
- assert get_args(result.elem_type) == (int, TypeKind("Int64"))
108
- assert not result.nullable
94
+ assert result.elem_type == np.int64
95
+ assert result.nullable is False
96
+ assert result.np_number_type is not None
97
+ assert get_origin(result.np_number_type) == np.ndarray
98
+ assert get_args(result.np_number_type)[1] == np.dtype[np.int64]
109
99
 
110
100
 
111
101
  def test_nullable_ndarray() -> None:
112
102
  typ = NDArray[np.float32] | None
113
103
  result = analyze_type_info(typ)
114
- assert result == AnalyzedTypeInfo(
115
- kind="Vector",
116
- vector_info=VectorInfo(dim=None),
117
- elem_type=Float32,
118
- key_type=None,
119
- struct_type=None,
120
- np_number_type=np.float32,
121
- attrs=None,
122
- nullable=True,
123
- )
104
+ assert result.kind == "Vector"
105
+ assert result.vector_info == VectorInfo(dim=None)
106
+ assert result.elem_type == np.float32
107
+ assert result.key_type is None
108
+ assert result.struct_type is None
109
+ assert result.nullable is True
110
+ assert result.np_number_type is not None
111
+ assert get_origin(result.np_number_type) == np.ndarray
112
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
113
+
114
+
115
+ def test_scalar_numpy_types() -> None:
116
+ for np_type, expected_kind in [
117
+ (np.int64, "Int64"),
118
+ (np.float32, "Float32"),
119
+ (np.float64, "Float64"),
120
+ ]:
121
+ type_info = analyze_type_info(np_type)
122
+ assert type_info.kind == expected_kind, (
123
+ f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
124
+ )
125
+ assert type_info.np_number_type == np_type, (
126
+ f"Expected {np_type}, got {type_info.np_number_type}"
127
+ )
128
+ assert type_info.elem_type is None
129
+ assert type_info.vector_info is None
124
130
 
125
131
 
126
132
  def test_vector_str() -> None:
127
133
  typ = Vector[str]
128
134
  result = analyze_type_info(typ)
129
135
  assert result.kind == "Vector"
130
- assert result.elem_type == str
136
+ assert result.elem_type is str
131
137
  assert result.vector_info == VectorInfo(dim=None)
132
138
 
133
139
 
@@ -143,7 +149,7 @@ def test_non_numpy_vector() -> None:
143
149
  typ = Vector[float, Literal[3]]
144
150
  result = analyze_type_info(typ)
145
151
  assert result.kind == "Vector"
146
- assert result.elem_type == float
152
+ assert result.elem_type is float
147
153
  assert result.vector_info == VectorInfo(dim=3)
148
154
 
149
155
 
@@ -156,10 +162,11 @@ def test_ndarray_any_dtype() -> None:
156
162
 
157
163
 
158
164
  def test_list_of_primitives() -> None:
159
- typ = List[str]
165
+ typ = list[str]
160
166
  result = analyze_type_info(typ)
161
167
  assert result == AnalyzedTypeInfo(
162
168
  kind="Vector",
169
+ core_type=list[str],
163
170
  vector_info=VectorInfo(dim=None),
164
171
  elem_type=str,
165
172
  key_type=None,
@@ -171,10 +178,11 @@ def test_list_of_primitives() -> None:
171
178
 
172
179
 
173
180
  def test_list_of_structs() -> None:
174
- typ = List[SimpleDataclass]
181
+ typ = list[SimpleDataclass]
175
182
  result = analyze_type_info(typ)
176
183
  assert result == AnalyzedTypeInfo(
177
184
  kind="LTable",
185
+ core_type=list[SimpleDataclass],
178
186
  vector_info=None,
179
187
  elem_type=SimpleDataclass,
180
188
  key_type=None,
@@ -190,6 +198,7 @@ def test_sequence_of_int() -> None:
190
198
  result = analyze_type_info(typ)
191
199
  assert result == AnalyzedTypeInfo(
192
200
  kind="Vector",
201
+ core_type=Sequence[int],
193
202
  vector_info=VectorInfo(dim=None),
194
203
  elem_type=int,
195
204
  key_type=None,
@@ -201,10 +210,11 @@ def test_sequence_of_int() -> None:
201
210
 
202
211
 
203
212
  def test_list_with_vector_info() -> None:
204
- typ = Annotated[List[int], VectorInfo(dim=5)]
213
+ typ = Annotated[list[int], VectorInfo(dim=5)]
205
214
  result = analyze_type_info(typ)
206
215
  assert result == AnalyzedTypeInfo(
207
216
  kind="Vector",
217
+ core_type=list[int],
208
218
  vector_info=VectorInfo(dim=5),
209
219
  elem_type=int,
210
220
  key_type=None,
@@ -216,10 +226,11 @@ def test_list_with_vector_info() -> None:
216
226
 
217
227
 
218
228
  def test_dict_str_int() -> None:
219
- typ = Dict[str, int]
229
+ typ = dict[str, int]
220
230
  result = analyze_type_info(typ)
221
231
  assert result == AnalyzedTypeInfo(
222
232
  kind="KTable",
233
+ core_type=dict[str, int],
223
234
  vector_info=None,
224
235
  elem_type=(str, int),
225
236
  key_type=None,
@@ -235,6 +246,7 @@ def test_mapping_str_dataclass() -> None:
235
246
  result = analyze_type_info(typ)
236
247
  assert result == AnalyzedTypeInfo(
237
248
  kind="KTable",
249
+ core_type=Mapping[str, SimpleDataclass],
238
250
  vector_info=None,
239
251
  elem_type=(str, SimpleDataclass),
240
252
  key_type=None,
@@ -250,6 +262,7 @@ def test_dataclass() -> None:
250
262
  result = analyze_type_info(typ)
251
263
  assert result == AnalyzedTypeInfo(
252
264
  kind="Struct",
265
+ core_type=SimpleDataclass,
253
266
  vector_info=None,
254
267
  elem_type=None,
255
268
  key_type=None,
@@ -265,6 +278,7 @@ def test_named_tuple() -> None:
265
278
  result = analyze_type_info(typ)
266
279
  assert result == AnalyzedTypeInfo(
267
280
  kind="Struct",
281
+ core_type=SimpleNamedTuple,
268
282
  vector_info=None,
269
283
  elem_type=None,
270
284
  key_type=None,
@@ -280,6 +294,7 @@ def test_tuple_key_value() -> None:
280
294
  result = analyze_type_info(typ)
281
295
  assert result == AnalyzedTypeInfo(
282
296
  kind="Int64",
297
+ core_type=int,
283
298
  vector_info=None,
284
299
  elem_type=None,
285
300
  key_type=str,
@@ -295,6 +310,7 @@ def test_str() -> None:
295
310
  result = analyze_type_info(typ)
296
311
  assert result == AnalyzedTypeInfo(
297
312
  kind="Str",
313
+ core_type=str,
298
314
  vector_info=None,
299
315
  elem_type=None,
300
316
  key_type=None,
@@ -310,6 +326,7 @@ def test_bool() -> None:
310
326
  result = analyze_type_info(typ)
311
327
  assert result == AnalyzedTypeInfo(
312
328
  kind="Bool",
329
+ core_type=bool,
313
330
  vector_info=None,
314
331
  elem_type=None,
315
332
  key_type=None,
@@ -325,6 +342,7 @@ def test_bytes() -> None:
325
342
  result = analyze_type_info(typ)
326
343
  assert result == AnalyzedTypeInfo(
327
344
  kind="Bytes",
345
+ core_type=bytes,
328
346
  vector_info=None,
329
347
  elem_type=None,
330
348
  key_type=None,
@@ -340,6 +358,7 @@ def test_uuid() -> None:
340
358
  result = analyze_type_info(typ)
341
359
  assert result == AnalyzedTypeInfo(
342
360
  kind="Uuid",
361
+ core_type=uuid.UUID,
343
362
  vector_info=None,
344
363
  elem_type=None,
345
364
  key_type=None,
@@ -355,6 +374,7 @@ def test_date() -> None:
355
374
  result = analyze_type_info(typ)
356
375
  assert result == AnalyzedTypeInfo(
357
376
  kind="Date",
377
+ core_type=datetime.date,
358
378
  vector_info=None,
359
379
  elem_type=None,
360
380
  key_type=None,
@@ -370,6 +390,7 @@ def test_time() -> None:
370
390
  result = analyze_type_info(typ)
371
391
  assert result == AnalyzedTypeInfo(
372
392
  kind="Time",
393
+ core_type=datetime.time,
373
394
  vector_info=None,
374
395
  elem_type=None,
375
396
  key_type=None,
@@ -385,6 +406,7 @@ def test_timedelta() -> None:
385
406
  result = analyze_type_info(typ)
386
407
  assert result == AnalyzedTypeInfo(
387
408
  kind="TimeDelta",
409
+ core_type=datetime.timedelta,
388
410
  vector_info=None,
389
411
  elem_type=None,
390
412
  key_type=None,
@@ -400,6 +422,7 @@ def test_float() -> None:
400
422
  result = analyze_type_info(typ)
401
423
  assert result == AnalyzedTypeInfo(
402
424
  kind="Float64",
425
+ core_type=float,
403
426
  vector_info=None,
404
427
  elem_type=None,
405
428
  key_type=None,
@@ -415,6 +438,7 @@ def test_int() -> None:
415
438
  result = analyze_type_info(typ)
416
439
  assert result == AnalyzedTypeInfo(
417
440
  kind="Int64",
441
+ core_type=int,
418
442
  vector_info=None,
419
443
  elem_type=None,
420
444
  key_type=None,
@@ -430,6 +454,7 @@ def test_type_with_attributes() -> None:
430
454
  result = analyze_type_info(typ)
431
455
  assert result == AnalyzedTypeInfo(
432
456
  kind="Str",
457
+ core_type=str,
433
458
  vector_info=None,
434
459
  elem_type=None,
435
460
  key_type=None,
@@ -466,7 +491,7 @@ def test_encode_enriched_type_vector() -> None:
466
491
 
467
492
 
468
493
  def test_encode_enriched_type_ltable() -> None:
469
- typ = List[SimpleDataclass]
494
+ typ = list[SimpleDataclass]
470
495
  result = encode_enriched_type(typ)
471
496
  assert result["type"]["kind"] == "LTable"
472
497
  assert result["type"]["row"]["kind"] == "Struct"
@@ -487,6 +512,19 @@ def test_encode_enriched_type_nullable() -> None:
487
512
  assert result["nullable"] is True
488
513
 
489
514
 
515
+ def test_encode_scalar_numpy_types_schema() -> None:
516
+ for np_type, expected_kind in [
517
+ (np.int64, "Int64"),
518
+ (np.float32, "Float32"),
519
+ (np.float64, "Float64"),
520
+ ]:
521
+ schema = encode_enriched_type(np_type)
522
+ assert schema["type"]["kind"] == expected_kind, (
523
+ f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
524
+ )
525
+ assert not schema.get("nullable", False)
526
+
527
+
490
528
  def test_invalid_struct_kind() -> None:
491
529
  typ = Annotated[SimpleDataclass, TypeKind("Vector")]
492
530
  with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
@@ -494,7 +532,7 @@ def test_invalid_struct_kind() -> None:
494
532
 
495
533
 
496
534
  def test_invalid_list_kind() -> None:
497
- typ = Annotated[List[int], TypeKind("Struct")]
535
+ typ = Annotated[list[int], TypeKind("Struct")]
498
536
  with pytest.raises(ValueError, match="Unexpected type kind for list: Struct"):
499
537
  analyze_type_info(typ)
500
538
 
cocoindex/typing.py CHANGED
@@ -1,22 +1,22 @@
1
- import typing
2
1
  import collections
3
2
  import dataclasses
4
3
  import datetime
5
- import types
6
4
  import inspect
5
+ import types
6
+ import typing
7
7
  import uuid
8
8
  from typing import (
9
+ TYPE_CHECKING,
9
10
  Annotated,
10
- NamedTuple,
11
11
  Any,
12
- KeysView,
13
- TypeVar,
14
- TYPE_CHECKING,
15
- overload,
16
12
  Generic,
17
13
  Literal,
14
+ NamedTuple,
18
15
  Protocol,
16
+ TypeVar,
17
+ overload,
19
18
  )
19
+
20
20
  import numpy as np
21
21
  from numpy.typing import NDArray
22
22
 
@@ -67,7 +67,7 @@ else:
67
67
  # No dimension provided, e.g., Vector[np.float32]
68
68
  dtype = params
69
69
  # Use NDArray for supported numeric dtypes, else list
70
- if DtypeRegistry.get_by_dtype(dtype) is not None:
70
+ if dtype in DtypeRegistry._DTYPE_TO_KIND:
71
71
  return Annotated[NDArray[dtype], VectorInfo(dim=None)]
72
72
  return Annotated[list[dtype], VectorInfo(dim=None)]
73
73
  else:
@@ -79,7 +79,7 @@ else:
79
79
  if typing.get_origin(dim_literal) is Literal
80
80
  else None
81
81
  )
82
- if DtypeRegistry.get_by_dtype(dtype) is not None:
82
+ if dtype in DtypeRegistry._DTYPE_TO_KIND:
83
83
  return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
84
84
  return Annotated[list[dtype], VectorInfo(dim=dim_val)]
85
85
 
@@ -90,6 +90,19 @@ KEY_FIELD_NAME: str = "_key"
90
90
  ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
91
91
 
92
92
 
93
+ def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
94
+ args = typing.get_args(ndarray_type)
95
+ _, dtype_spec = args
96
+ dtype_args = typing.get_args(dtype_spec)
97
+ if not dtype_args:
98
+ raise ValueError(f"Invalid dtype specification: {dtype_spec}")
99
+ return dtype_args[0]
100
+
101
+
102
+ def is_numpy_number_type(t: type) -> bool:
103
+ return isinstance(t, type) and issubclass(t, np.number)
104
+
105
+
93
106
  def is_namedtuple_type(t: type) -> bool:
94
107
  return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")
95
108
 
@@ -100,41 +113,34 @@ def _is_struct_type(t: ElementType | None) -> bool:
100
113
  )
101
114
 
102
115
 
103
- class DtypeInfo:
104
- """Metadata for a NumPy dtype."""
105
-
106
- def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
107
- self.numpy_dtype = numpy_dtype
108
- self.kind = kind
109
- self.python_type = python_type
110
- self.annotated_type = Annotated[python_type, TypeKind(kind)]
111
-
112
-
113
116
  class DtypeRegistry:
114
117
  """
115
118
  Registry for NumPy dtypes used in CocoIndex.
116
- Provides mappings from NumPy dtypes to CocoIndex's type representation.
119
+ Maps NumPy dtypes to their CocoIndex type kind.
117
120
  """
118
121
 
119
- _mappings: dict[type, DtypeInfo] = {
120
- np.float32: DtypeInfo(np.float32, "Float32", float),
121
- np.float64: DtypeInfo(np.float64, "Float64", float),
122
- np.int64: DtypeInfo(np.int64, "Int64", int),
122
+ _DTYPE_TO_KIND: dict[ElementType, str] = {
123
+ np.float32: "Float32",
124
+ np.float64: "Float64",
125
+ np.int64: "Int64",
123
126
  }
124
127
 
125
128
  @classmethod
126
- def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
127
- """Get DtypeInfo by NumPy dtype."""
129
+ def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str:
130
+ """
131
+ Validate that the given dtype is supported, and get its CocoIndex kind by dtype.
132
+ """
128
133
  if dtype is Any:
129
134
  raise TypeError(
130
135
  "NDArray for Vector must use a concrete numpy dtype, got `Any`."
131
136
  )
132
- return cls._mappings.get(dtype)
133
-
134
- @staticmethod
135
- def supported_dtypes() -> KeysView[type]:
136
- """Get a list of supported NumPy dtypes."""
137
- return DtypeRegistry._mappings.keys()
137
+ kind = cls._DTYPE_TO_KIND.get(dtype)
138
+ if kind is None:
139
+ raise ValueError(
140
+ f"Unsupported NumPy dtype in NDArray: {dtype}. "
141
+ f"Supported dtypes: {cls._DTYPE_TO_KIND.keys()}"
142
+ )
143
+ return kind
138
144
 
139
145
 
140
146
  @dataclasses.dataclass
@@ -144,6 +150,7 @@ class AnalyzedTypeInfo:
144
150
  """
145
151
 
146
152
  kind: str
153
+ core_type: Any
147
154
  vector_info: VectorInfo | None # For Vector
148
155
  elem_type: ElementType | None # For Vector and Table
149
156
 
@@ -155,6 +162,7 @@ class AnalyzedTypeInfo:
155
162
 
156
163
  attrs: dict[str, Any] | None
157
164
  nullable: bool = False
165
+ union_variant_types: typing.List[ElementType] | None = None # For Union
158
166
 
159
167
 
160
168
  def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
@@ -175,18 +183,6 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
175
183
  if base_type is Annotated:
176
184
  annotations = t.__metadata__
177
185
  t = t.__origin__
178
- elif base_type is types.UnionType:
179
- possible_types = typing.get_args(t)
180
- non_none_types = [
181
- arg for arg in possible_types if arg not in (None, types.NoneType)
182
- ]
183
- if len(non_none_types) != 1:
184
- raise ValueError(
185
- f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}"
186
- )
187
- t = non_none_types[0]
188
- if len(possible_types) > 1:
189
- nullable = True
190
186
  else:
191
187
  break
192
188
 
@@ -205,6 +201,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
205
201
 
206
202
  struct_type: type | None = None
207
203
  elem_type: ElementType | None = None
204
+ union_variant_types: typing.List[ElementType] | None = None
208
205
  key_type: type | None = None
209
206
  np_number_type: type | None = None
210
207
  if _is_struct_type(t):
@@ -214,6 +211,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
214
211
  kind = "Struct"
215
212
  elif kind != "Struct":
216
213
  raise ValueError(f"Unexpected type kind for struct: {kind}")
214
+ elif is_numpy_number_type(t):
215
+ np_number_type = t
216
+ kind = DtypeRegistry.validate_dtype_and_get_kind(t)
217
217
  elif base_type is collections.abc.Sequence or base_type is list:
218
218
  args = typing.get_args(t)
219
219
  elem_type = args[0]
@@ -233,33 +233,35 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
233
233
  raise ValueError(f"Unexpected type kind for list: {kind}")
234
234
  elif base_type is np.ndarray:
235
235
  kind = "Vector"
236
- args = typing.get_args(t)
237
- _, dtype_spec = args
238
-
239
- dtype_args = typing.get_args(dtype_spec)
240
- if not dtype_args:
241
- raise ValueError("Invalid dtype specification for NDArray")
242
-
243
- np_number_type = dtype_args[0]
244
- dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
245
- if dtype_info is None:
246
- raise ValueError(
247
- f"Unsupported numpy dtype for NDArray: {np_number_type}. "
248
- f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
249
- )
250
- elem_type = dtype_info.annotated_type
236
+ np_number_type = t
237
+ elem_type = extract_ndarray_scalar_dtype(np_number_type)
238
+ _ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
251
239
  vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
252
240
 
253
241
  elif base_type is collections.abc.Mapping or base_type is dict:
254
242
  args = typing.get_args(t)
255
243
  elem_type = (args[0], args[1])
256
244
  kind = "KTable"
245
+ elif base_type is types.UnionType:
246
+ possible_types = typing.get_args(t)
247
+ non_none_types = [
248
+ arg for arg in possible_types if arg not in (None, types.NoneType)
249
+ ]
250
+
251
+ if len(non_none_types) == 0:
252
+ return analyze_type_info(None)
253
+
254
+ nullable = len(non_none_types) < len(possible_types)
255
+
256
+ if len(non_none_types) == 1:
257
+ result = analyze_type_info(non_none_types[0])
258
+ result.nullable = nullable
259
+ return result
260
+
261
+ kind = "Union"
262
+ union_variant_types = non_none_types
257
263
  elif kind is None:
258
- dtype_info = DtypeRegistry.get_by_dtype(t)
259
- if dtype_info is not None:
260
- kind = dtype_info.kind
261
- np_number_type = dtype_info.numpy_dtype
262
- elif t is bytes:
264
+ if t is bytes:
263
265
  kind = "Bytes"
264
266
  elif t is str:
265
267
  kind = "Str"
@@ -284,8 +286,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
284
286
 
285
287
  return AnalyzedTypeInfo(
286
288
  kind=kind,
289
+ core_type=t,
287
290
  vector_info=vector_info,
288
291
  elem_type=elem_type,
292
+ union_variant_types=union_variant_types,
289
293
  key_type=key_type,
290
294
  struct_type=struct_type,
291
295
  np_number_type=np_number_type,
@@ -345,6 +349,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
345
349
  encoded_type["element_type"] = _encode_type(elem_type_info)
346
350
  encoded_type["dimension"] = type_info.vector_info.dim
347
351
 
352
+ elif type_info.kind == "Union":
353
+ if type_info.union_variant_types is None:
354
+ raise ValueError("Union type must have a variant type list")
355
+ encoded_type["types"] = [
356
+ _encode_type(analyze_type_info(typ))
357
+ for typ in type_info.union_variant_types
358
+ ]
359
+
348
360
  elif type_info.kind in TABLE_TYPES:
349
361
  if type_info.elem_type is None:
350
362
  raise ValueError(f"{type_info.kind} type must have an element type")