cocoindex 0.1.52__cp312-cp312-macosx_11_0_arm64.whl → 0.1.53__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
cocoindex/convert.py CHANGED
@@ -6,17 +6,19 @@ import dataclasses
6
6
  import datetime
7
7
  import inspect
8
8
  import uuid
9
+ from enum import Enum
10
+ from typing import Any, Callable, Mapping, get_origin
11
+
9
12
  import numpy as np
10
13
 
11
- from enum import Enum
12
- from typing import Any, Callable, get_origin, Mapping
13
14
  from .typing import (
15
+ KEY_FIELD_NAME,
16
+ TABLE_TYPES,
17
+ DtypeRegistry,
14
18
  analyze_type_info,
15
19
  encode_enriched_type,
20
+ extract_ndarray_scalar_dtype,
16
21
  is_namedtuple_type,
17
- TABLE_TYPES,
18
- KEY_FIELD_NAME,
19
- DtypeRegistry,
20
22
  )
21
23
 
22
24
 
@@ -29,6 +31,8 @@ def encode_engine_value(value: Any) -> Any:
29
31
  ]
30
32
  if is_namedtuple_type(type(value)):
31
33
  return [encode_engine_value(getattr(value, name)) for name in value._fields]
34
+ if isinstance(value, np.number):
35
+ return value.item()
32
36
  if isinstance(value, np.ndarray):
33
37
  return value
34
38
  if isinstance(value, (list, tuple)):
@@ -86,6 +90,20 @@ def make_engine_value_decoder(
86
90
  field_path, src_type["fields"], dst_type_info.struct_type
87
91
  )
88
92
 
93
+ if dst_type_info.np_number_type is not None and src_type_kind != "Vector":
94
+ numpy_type = dst_type_info.np_number_type
95
+
96
+ def decode_numpy_scalar(value: Any) -> Any | None:
97
+ if value is None:
98
+ if dst_type_info.nullable:
99
+ return None
100
+ raise ValueError(
101
+ f"Received null for non-nullable scalar `{''.join(field_path)}`"
102
+ )
103
+ return numpy_type(value)
104
+
105
+ return decode_numpy_scalar
106
+
89
107
  if src_type_kind in TABLE_TYPES:
90
108
  field_path.append("[*]")
91
109
  elem_type_info = analyze_type_info(dst_type_info.elem_type)
@@ -127,33 +145,42 @@ def make_engine_value_decoder(
127
145
  return lambda value: uuid.UUID(bytes=value)
128
146
 
129
147
  if src_type_kind == "Vector":
130
- dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)
131
148
 
132
149
  def decode_vector(value: Any) -> Any | None:
150
+ field_path_str = "".join(field_path)
151
+ expected_dim = (
152
+ dst_type_info.vector_info.dim if dst_type_info.vector_info else None
153
+ )
154
+
133
155
  if value is None:
134
156
  if dst_type_info.nullable:
135
157
  return None
136
158
  raise ValueError(
137
- f"Received null for non-nullable vector `{''.join(field_path)}`"
159
+ f"Received null for non-nullable vector `{field_path_str}`"
138
160
  )
139
-
140
161
  if not isinstance(value, (np.ndarray, list)):
141
162
  raise TypeError(
142
- f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
163
+ f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
143
164
  )
144
- expected_dim = (
145
- dst_type_info.vector_info.dim if dst_type_info.vector_info else None
146
- )
147
165
  if expected_dim is not None and len(value) != expected_dim:
148
166
  raise ValueError(
149
- f"Vector dimension mismatch for `{''.join(field_path)}`: "
167
+ f"Vector dimension mismatch for `{field_path_str}`: "
150
168
  f"expected {expected_dim}, got {len(value)}"
151
169
  )
152
170
 
153
- # Use NDArray for supported numeric dtypes, else return list
154
- if dtype_info is not None:
155
- return np.array(value, dtype=dtype_info.numpy_dtype)
156
- return value
171
+ if dst_type_info.np_number_type is None: # for Non-NDArray vector
172
+ elem_decoder = make_engine_value_decoder(
173
+ field_path + ["[*]"],
174
+ src_type["element_type"],
175
+ dst_type_info.elem_type,
176
+ )
177
+ return [elem_decoder(v) for v in value]
178
+ else: # for NDArray vector
179
+ scalar_dtype = extract_ndarray_scalar_dtype(
180
+ dst_type_info.np_number_type
181
+ )
182
+ _ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
183
+ return np.array(value, dtype=scalar_dtype)
157
184
 
158
185
  return decode_vector
159
186
 
cocoindex/llm.py CHANGED
@@ -9,6 +9,8 @@ class LlmApiType(Enum):
9
9
  OLLAMA = "Ollama"
10
10
  GEMINI = "Gemini"
11
11
  ANTHROPIC = "Anthropic"
12
+ LITE_LLM = "LiteLlm"
13
+ OPEN_ROUTER = "OpenRouter"
12
14
 
13
15
 
14
16
  @dataclass
@@ -1,22 +1,25 @@
1
- import uuid
2
1
  import datetime
2
+ import uuid
3
3
  from dataclasses import dataclass, make_dataclass
4
- from typing import NamedTuple, Literal, Any, Callable, Union
4
+ from typing import Annotated, Any, Callable, Literal, NamedTuple
5
+
6
+ import numpy as np
5
7
  import pytest
8
+ from numpy.typing import NDArray
9
+
6
10
  import cocoindex
7
- from cocoindex.typing import (
8
- encode_enriched_type,
9
- Vector,
10
- Float32,
11
- Float64,
12
- )
13
11
  from cocoindex.convert import (
12
+ dump_engine_object,
14
13
  encode_engine_value,
15
14
  make_engine_value_decoder,
16
- dump_engine_object,
17
15
  )
18
- import numpy as np
19
- from numpy.typing import NDArray
16
+ from cocoindex.typing import (
17
+ Float32,
18
+ Float64,
19
+ TypeKind,
20
+ Vector,
21
+ encode_enriched_type,
22
+ )
20
23
 
21
24
 
22
25
  @dataclass
@@ -128,6 +131,19 @@ def test_encode_engine_value_date_time_types() -> None:
128
131
  assert encode_engine_value(dt) == dt
129
132
 
130
133
 
134
+ def test_encode_scalar_numpy_values() -> None:
135
+ """Test encoding scalar NumPy values to engine-compatible values."""
136
+ test_cases = [
137
+ (np.int64(42), 42),
138
+ (np.float32(3.14), pytest.approx(3.14)),
139
+ (np.float64(2.718), pytest.approx(2.718)),
140
+ ]
141
+ for np_value, expected in test_cases:
142
+ encoded = encode_engine_value(np_value)
143
+ assert encoded == expected
144
+ assert isinstance(encoded, (int, float))
145
+
146
+
131
147
  def test_encode_engine_value_struct() -> None:
132
148
  order = Order(order_id="O123", name="mixed nuts", price=25.0)
133
149
  assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
@@ -213,6 +229,47 @@ def test_roundtrip_basic_types() -> None:
213
229
  )
214
230
 
215
231
 
232
+ def test_decode_scalar_numpy_values() -> None:
233
+ test_cases = [
234
+ ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
235
+ ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
236
+ ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
237
+ ]
238
+ for src_type, dst_type, input_value, expected in test_cases:
239
+ decoder = make_engine_value_decoder(["field"], src_type, dst_type)
240
+ result = decoder(input_value)
241
+ assert isinstance(result, dst_type)
242
+ assert result == expected
243
+
244
+
245
+ def test_non_ndarray_vector_decoding() -> None:
246
+ # Test list[np.float64]
247
+ src_type = {
248
+ "kind": "Vector",
249
+ "element_type": {"kind": "Float64"},
250
+ "dimension": None,
251
+ }
252
+ dst_type_float = list[np.float64]
253
+ decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
254
+ input_numbers = [1.0, 2.0, 3.0]
255
+ result = decoder(input_numbers)
256
+ assert isinstance(result, list)
257
+ assert all(isinstance(x, np.float64) for x in result)
258
+ assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
259
+
260
+ # Test list[Uuid]
261
+ src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
262
+ dst_type_uuid = list[uuid.UUID]
263
+ decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
264
+ uuid1 = uuid.uuid4()
265
+ uuid2 = uuid.uuid4()
266
+ input_bytes = [uuid1.bytes, uuid2.bytes]
267
+ result = decoder(input_bytes)
268
+ assert isinstance(result, list)
269
+ assert all(isinstance(x, uuid.UUID) for x in result)
270
+ assert result == [uuid1, uuid2]
271
+
272
+
216
273
  @pytest.mark.parametrize(
217
274
  "data_type, engine_val, expected",
218
275
  [
@@ -565,12 +622,6 @@ Float64VectorTypeNoDim = Vector[np.float64]
565
622
  Float32VectorType = Vector[np.float32, Literal[3]]
566
623
  Float64VectorType = Vector[np.float64, Literal[3]]
567
624
  Int64VectorType = Vector[np.int64, Literal[3]]
568
- Int32VectorType = Vector[np.int32, Literal[3]]
569
- UInt8VectorType = Vector[np.uint8, Literal[3]]
570
- UInt16VectorType = Vector[np.uint16, Literal[3]]
571
- UInt32VectorType = Vector[np.uint32, Literal[3]]
572
- UInt64VectorType = Vector[np.uint64, Literal[3]]
573
- StrVectorType = Vector[str]
574
625
  NDArrayFloat32Type = NDArray[np.float32]
575
626
  NDArrayFloat64Type = NDArray[np.float64]
576
627
  NDArrayInt64Type = NDArray[np.int64]
@@ -767,19 +818,19 @@ def test_full_roundtrip_vector_numeric_types() -> None:
767
818
  value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
768
819
  validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
769
820
  value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
770
- with pytest.raises(ValueError, match="type unsupported yet"):
821
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
771
822
  validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
772
823
  value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
773
- with pytest.raises(ValueError, match="type unsupported yet"):
824
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
774
825
  validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
775
826
  value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
776
- with pytest.raises(ValueError, match="type unsupported yet"):
827
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
777
828
  validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
778
829
  value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
779
- with pytest.raises(ValueError, match="type unsupported yet"):
830
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
780
831
  validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
781
832
  value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
782
- with pytest.raises(ValueError, match="type unsupported yet"):
833
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
783
834
  validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
784
835
 
785
836
 
@@ -808,7 +859,88 @@ def test_roundtrip_dimension_mismatch() -> None:
808
859
  validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
809
860
 
810
861
 
811
- def test_roundtrip_list_backward_compatibility() -> None:
812
- """Test full roundtrip for list-based vectors for backward compatibility."""
813
- value_list: list[int] = [1, 2, 3]
814
- validate_full_roundtrip(value_list, list[int])
862
+ def test_full_roundtrip_scalar_numeric_types() -> None:
863
+ """Test full roundtrip for scalar NumPy numeric types."""
864
+ # Test supported scalar types
865
+ validate_full_roundtrip(np.int64(42), np.int64)
866
+ validate_full_roundtrip(np.float32(3.14), np.float32)
867
+ validate_full_roundtrip(np.float64(2.718), np.float64)
868
+
869
+ # Test unsupported scalar types
870
+ for unsupported_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]:
871
+ with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
872
+ validate_full_roundtrip(unsupported_type(1), unsupported_type)
873
+
874
+
875
+ def test_full_roundtrip_nullable_scalar() -> None:
876
+ """Test full roundtrip for nullable scalar NumPy types."""
877
+ # Test with non-null values
878
+ validate_full_roundtrip(np.int64(42), np.int64 | None)
879
+ validate_full_roundtrip(np.float32(3.14), np.float32 | None)
880
+ validate_full_roundtrip(np.float64(2.718), np.float64 | None)
881
+
882
+ # Test with None
883
+ validate_full_roundtrip(None, np.int64 | None)
884
+ validate_full_roundtrip(None, np.float32 | None)
885
+ validate_full_roundtrip(None, np.float64 | None)
886
+
887
+
888
+ def test_full_roundtrip_scalar_in_struct() -> None:
889
+ """Test full roundtrip for scalar NumPy types in a dataclass."""
890
+
891
+ @dataclass
892
+ class NumericStruct:
893
+ int_field: np.int64
894
+ float32_field: np.float32
895
+ float64_field: np.float64
896
+
897
+ instance = NumericStruct(
898
+ int_field=np.int64(42),
899
+ float32_field=np.float32(3.14),
900
+ float64_field=np.float64(2.718),
901
+ )
902
+ validate_full_roundtrip(instance, NumericStruct)
903
+
904
+
905
+ def test_full_roundtrip_scalar_in_nested_struct() -> None:
906
+ """Test full roundtrip for scalar NumPy types in a nested struct."""
907
+
908
+ @dataclass
909
+ class InnerStruct:
910
+ value: np.float64
911
+
912
+ @dataclass
913
+ class OuterStruct:
914
+ inner: InnerStruct
915
+ count: np.int64
916
+
917
+ instance = OuterStruct(
918
+ inner=InnerStruct(value=np.float64(2.718)),
919
+ count=np.int64(1),
920
+ )
921
+ validate_full_roundtrip(instance, OuterStruct)
922
+
923
+
924
+ def test_full_roundtrip_scalar_with_python_types() -> None:
925
+ """Test full roundtrip for structs mixing NumPy and Python scalar types."""
926
+
927
+ @dataclass
928
+ class MixedStruct:
929
+ numpy_int: np.int64
930
+ python_int: int
931
+ numpy_float: np.float64
932
+ python_float: float
933
+ string: str
934
+ annotated_int: Annotated[np.int64, TypeKind("int")]
935
+ annotated_float: Float32
936
+
937
+ instance = MixedStruct(
938
+ numpy_int=np.int64(42),
939
+ python_int=43,
940
+ numpy_float=np.float64(2.718),
941
+ python_float=3.14,
942
+ string="hello, world",
943
+ annotated_int=np.int64(42),
944
+ annotated_float=2.0,
945
+ )
946
+ validate_full_roundtrip(instance, MixedStruct)
@@ -1,30 +1,21 @@
1
1
  import dataclasses
2
2
  import datetime
3
3
  import uuid
4
- from typing import (
5
- Annotated,
6
- List,
7
- Dict,
8
- Literal,
9
- Any,
10
- get_args,
11
- NamedTuple,
12
- )
13
- from collections.abc import Sequence, Mapping
14
- import pytest
4
+ from collections.abc import Mapping, Sequence
5
+ from typing import Annotated, Any, Dict, List, Literal, NamedTuple, get_args, get_origin
6
+
15
7
  import numpy as np
8
+ import pytest
16
9
  from numpy.typing import NDArray
17
10
 
18
11
  from cocoindex.typing import (
19
- analyze_type_info,
12
+ AnalyzedTypeInfo,
13
+ TypeAttr,
14
+ TypeKind,
20
15
  Vector,
21
16
  VectorInfo,
22
- TypeKind,
23
- TypeAttr,
24
- Float32,
25
- Float64,
17
+ analyze_type_info,
26
18
  encode_enriched_type,
27
- AnalyzedTypeInfo,
28
19
  )
29
20
 
30
21
 
@@ -42,61 +33,57 @@ class SimpleNamedTuple(NamedTuple):
42
33
  def test_ndarray_float32_no_dim() -> None:
43
34
  typ = NDArray[np.float32]
44
35
  result = analyze_type_info(typ)
45
- assert result == AnalyzedTypeInfo(
46
- kind="Vector",
47
- vector_info=VectorInfo(dim=None),
48
- elem_type=Float32,
49
- key_type=None,
50
- struct_type=None,
51
- np_number_type=np.float32,
52
- attrs=None,
53
- nullable=False,
54
- )
36
+ assert result.kind == "Vector"
37
+ assert result.vector_info == VectorInfo(dim=None)
38
+ assert result.elem_type == np.float32
39
+ assert result.key_type is None
40
+ assert result.struct_type is None
41
+ assert result.nullable is False
42
+ assert result.np_number_type is not None
43
+ assert get_origin(result.np_number_type) == np.ndarray
44
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
55
45
 
56
46
 
57
47
  def test_vector_float32_no_dim() -> None:
58
48
  typ = Vector[np.float32]
59
49
  result = analyze_type_info(typ)
60
- assert result == AnalyzedTypeInfo(
61
- kind="Vector",
62
- vector_info=VectorInfo(dim=None),
63
- elem_type=Float32,
64
- key_type=None,
65
- struct_type=None,
66
- np_number_type=np.float32,
67
- attrs=None,
68
- nullable=False,
69
- )
50
+ assert result.kind == "Vector"
51
+ assert result.vector_info == VectorInfo(dim=None)
52
+ assert result.elem_type == np.float32
53
+ assert result.key_type is None
54
+ assert result.struct_type is None
55
+ assert result.nullable is False
56
+ assert result.np_number_type is not None
57
+ assert get_origin(result.np_number_type) == np.ndarray
58
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
70
59
 
71
60
 
72
61
  def test_ndarray_float64_with_dim() -> None:
73
62
  typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
74
63
  result = analyze_type_info(typ)
75
- assert result == AnalyzedTypeInfo(
76
- kind="Vector",
77
- vector_info=VectorInfo(dim=128),
78
- elem_type=Float64,
79
- key_type=None,
80
- struct_type=None,
81
- np_number_type=np.float64,
82
- attrs=None,
83
- nullable=False,
84
- )
64
+ assert result.kind == "Vector"
65
+ assert result.vector_info == VectorInfo(dim=128)
66
+ assert result.elem_type == np.float64
67
+ assert result.key_type is None
68
+ assert result.struct_type is None
69
+ assert result.nullable is False
70
+ assert result.np_number_type is not None
71
+ assert get_origin(result.np_number_type) == np.ndarray
72
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float64]
85
73
 
86
74
 
87
75
  def test_vector_float32_with_dim() -> None:
88
76
  typ = Vector[np.float32, Literal[384]]
89
77
  result = analyze_type_info(typ)
90
- assert result == AnalyzedTypeInfo(
91
- kind="Vector",
92
- vector_info=VectorInfo(dim=384),
93
- elem_type=Float32,
94
- key_type=None,
95
- struct_type=None,
96
- np_number_type=np.float32,
97
- attrs=None,
98
- nullable=False,
99
- )
78
+ assert result.kind == "Vector"
79
+ assert result.vector_info == VectorInfo(dim=384)
80
+ assert result.elem_type == np.float32
81
+ assert result.key_type is None
82
+ assert result.struct_type is None
83
+ assert result.nullable is False
84
+ assert result.np_number_type is not None
85
+ assert get_origin(result.np_number_type) == np.ndarray
86
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
100
87
 
101
88
 
102
89
  def test_ndarray_int64_no_dim() -> None:
@@ -104,30 +91,49 @@ def test_ndarray_int64_no_dim() -> None:
104
91
  result = analyze_type_info(typ)
105
92
  assert result.kind == "Vector"
106
93
  assert result.vector_info == VectorInfo(dim=None)
107
- assert get_args(result.elem_type) == (int, TypeKind("Int64"))
108
- assert not result.nullable
94
+ assert result.elem_type == np.int64
95
+ assert result.nullable is False
96
+ assert result.np_number_type is not None
97
+ assert get_origin(result.np_number_type) == np.ndarray
98
+ assert get_args(result.np_number_type)[1] == np.dtype[np.int64]
109
99
 
110
100
 
111
101
  def test_nullable_ndarray() -> None:
112
102
  typ = NDArray[np.float32] | None
113
103
  result = analyze_type_info(typ)
114
- assert result == AnalyzedTypeInfo(
115
- kind="Vector",
116
- vector_info=VectorInfo(dim=None),
117
- elem_type=Float32,
118
- key_type=None,
119
- struct_type=None,
120
- np_number_type=np.float32,
121
- attrs=None,
122
- nullable=True,
123
- )
104
+ assert result.kind == "Vector"
105
+ assert result.vector_info == VectorInfo(dim=None)
106
+ assert result.elem_type == np.float32
107
+ assert result.key_type is None
108
+ assert result.struct_type is None
109
+ assert result.nullable is True
110
+ assert result.np_number_type is not None
111
+ assert get_origin(result.np_number_type) == np.ndarray
112
+ assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
113
+
114
+
115
+ def test_scalar_numpy_types() -> None:
116
+ for np_type, expected_kind in [
117
+ (np.int64, "Int64"),
118
+ (np.float32, "Float32"),
119
+ (np.float64, "Float64"),
120
+ ]:
121
+ type_info = analyze_type_info(np_type)
122
+ assert type_info.kind == expected_kind, (
123
+ f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
124
+ )
125
+ assert type_info.np_number_type == np_type, (
126
+ f"Expected {np_type}, got {type_info.np_number_type}"
127
+ )
128
+ assert type_info.elem_type is None
129
+ assert type_info.vector_info is None
124
130
 
125
131
 
126
132
  def test_vector_str() -> None:
127
133
  typ = Vector[str]
128
134
  result = analyze_type_info(typ)
129
135
  assert result.kind == "Vector"
130
- assert result.elem_type == str
136
+ assert result.elem_type is str
131
137
  assert result.vector_info == VectorInfo(dim=None)
132
138
 
133
139
 
@@ -143,7 +149,7 @@ def test_non_numpy_vector() -> None:
143
149
  typ = Vector[float, Literal[3]]
144
150
  result = analyze_type_info(typ)
145
151
  assert result.kind == "Vector"
146
- assert result.elem_type == float
152
+ assert result.elem_type is float
147
153
  assert result.vector_info == VectorInfo(dim=3)
148
154
 
149
155
 
@@ -487,6 +493,19 @@ def test_encode_enriched_type_nullable() -> None:
487
493
  assert result["nullable"] is True
488
494
 
489
495
 
496
+ def test_encode_scalar_numpy_types_schema() -> None:
497
+ for np_type, expected_kind in [
498
+ (np.int64, "Int64"),
499
+ (np.float32, "Float32"),
500
+ (np.float64, "Float64"),
501
+ ]:
502
+ schema = encode_enriched_type(np_type)
503
+ assert schema["type"]["kind"] == expected_kind, (
504
+ f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
505
+ )
506
+ assert not schema.get("nullable", False)
507
+
508
+
490
509
  def test_invalid_struct_kind() -> None:
491
510
  typ = Annotated[SimpleDataclass, TypeKind("Vector")]
492
511
  with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
cocoindex/typing.py CHANGED
@@ -1,22 +1,22 @@
1
- import typing
2
1
  import collections
3
2
  import dataclasses
4
3
  import datetime
5
- import types
6
4
  import inspect
5
+ import types
6
+ import typing
7
7
  import uuid
8
8
  from typing import (
9
+ TYPE_CHECKING,
9
10
  Annotated,
10
- NamedTuple,
11
11
  Any,
12
- KeysView,
13
- TypeVar,
14
- TYPE_CHECKING,
15
- overload,
16
12
  Generic,
17
13
  Literal,
14
+ NamedTuple,
18
15
  Protocol,
16
+ TypeVar,
17
+ overload,
19
18
  )
19
+
20
20
  import numpy as np
21
21
  from numpy.typing import NDArray
22
22
 
@@ -67,7 +67,7 @@ else:
67
67
  # No dimension provided, e.g., Vector[np.float32]
68
68
  dtype = params
69
69
  # Use NDArray for supported numeric dtypes, else list
70
- if DtypeRegistry.get_by_dtype(dtype) is not None:
70
+ if dtype in DtypeRegistry._DTYPE_TO_KIND:
71
71
  return Annotated[NDArray[dtype], VectorInfo(dim=None)]
72
72
  return Annotated[list[dtype], VectorInfo(dim=None)]
73
73
  else:
@@ -79,7 +79,7 @@ else:
79
79
  if typing.get_origin(dim_literal) is Literal
80
80
  else None
81
81
  )
82
- if DtypeRegistry.get_by_dtype(dtype) is not None:
82
+ if dtype in DtypeRegistry._DTYPE_TO_KIND:
83
83
  return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
84
84
  return Annotated[list[dtype], VectorInfo(dim=dim_val)]
85
85
 
@@ -90,6 +90,19 @@ KEY_FIELD_NAME: str = "_key"
90
90
  ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
91
91
 
92
92
 
93
+ def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
94
+ args = typing.get_args(ndarray_type)
95
+ _, dtype_spec = args
96
+ dtype_args = typing.get_args(dtype_spec)
97
+ if not dtype_args:
98
+ raise ValueError(f"Invalid dtype specification: {dtype_spec}")
99
+ return dtype_args[0]
100
+
101
+
102
+ def is_numpy_number_type(t: type) -> bool:
103
+ return isinstance(t, type) and issubclass(t, np.number)
104
+
105
+
93
106
  def is_namedtuple_type(t: type) -> bool:
94
107
  return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")
95
108
 
@@ -100,41 +113,34 @@ def _is_struct_type(t: ElementType | None) -> bool:
100
113
  )
101
114
 
102
115
 
103
- class DtypeInfo:
104
- """Metadata for a NumPy dtype."""
105
-
106
- def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
107
- self.numpy_dtype = numpy_dtype
108
- self.kind = kind
109
- self.python_type = python_type
110
- self.annotated_type = Annotated[python_type, TypeKind(kind)]
111
-
112
-
113
116
  class DtypeRegistry:
114
117
  """
115
118
  Registry for NumPy dtypes used in CocoIndex.
116
- Provides mappings from NumPy dtypes to CocoIndex's type representation.
119
+ Maps NumPy dtypes to their CocoIndex type kind.
117
120
  """
118
121
 
119
- _mappings: dict[type, DtypeInfo] = {
120
- np.float32: DtypeInfo(np.float32, "Float32", float),
121
- np.float64: DtypeInfo(np.float64, "Float64", float),
122
- np.int64: DtypeInfo(np.int64, "Int64", int),
122
+ _DTYPE_TO_KIND: dict[ElementType, str] = {
123
+ np.float32: "Float32",
124
+ np.float64: "Float64",
125
+ np.int64: "Int64",
123
126
  }
124
127
 
125
128
  @classmethod
126
- def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
127
- """Get DtypeInfo by NumPy dtype."""
129
+ def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str:
130
+ """
131
+ Validate that the given dtype is supported, and get its CocoIndex kind by dtype.
132
+ """
128
133
  if dtype is Any:
129
134
  raise TypeError(
130
135
  "NDArray for Vector must use a concrete numpy dtype, got `Any`."
131
136
  )
132
- return cls._mappings.get(dtype)
133
-
134
- @staticmethod
135
- def supported_dtypes() -> KeysView[type]:
136
- """Get a list of supported NumPy dtypes."""
137
- return DtypeRegistry._mappings.keys()
137
+ kind = cls._DTYPE_TO_KIND.get(dtype)
138
+ if kind is None:
139
+ raise ValueError(
140
+ f"Unsupported NumPy dtype in NDArray: {dtype}. "
141
+ f"Supported dtypes: {cls._DTYPE_TO_KIND.keys()}"
142
+ )
143
+ return kind
138
144
 
139
145
 
140
146
  @dataclasses.dataclass
@@ -214,6 +220,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
214
220
  kind = "Struct"
215
221
  elif kind != "Struct":
216
222
  raise ValueError(f"Unexpected type kind for struct: {kind}")
223
+ elif is_numpy_number_type(t):
224
+ np_number_type = t
225
+ kind = DtypeRegistry.validate_dtype_and_get_kind(t)
217
226
  elif base_type is collections.abc.Sequence or base_type is list:
218
227
  args = typing.get_args(t)
219
228
  elem_type = args[0]
@@ -233,21 +242,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
233
242
  raise ValueError(f"Unexpected type kind for list: {kind}")
234
243
  elif base_type is np.ndarray:
235
244
  kind = "Vector"
236
- args = typing.get_args(t)
237
- _, dtype_spec = args
238
-
239
- dtype_args = typing.get_args(dtype_spec)
240
- if not dtype_args:
241
- raise ValueError("Invalid dtype specification for NDArray")
242
-
243
- np_number_type = dtype_args[0]
244
- dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
245
- if dtype_info is None:
246
- raise ValueError(
247
- f"Unsupported numpy dtype for NDArray: {np_number_type}. "
248
- f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
249
- )
250
- elem_type = dtype_info.annotated_type
245
+ np_number_type = t
246
+ elem_type = extract_ndarray_scalar_dtype(np_number_type)
247
+ _ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
251
248
  vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
252
249
 
253
250
  elif base_type is collections.abc.Mapping or base_type is dict:
@@ -255,11 +252,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
255
252
  elem_type = (args[0], args[1])
256
253
  kind = "KTable"
257
254
  elif kind is None:
258
- dtype_info = DtypeRegistry.get_by_dtype(t)
259
- if dtype_info is not None:
260
- kind = dtype_info.kind
261
- np_number_type = dtype_info.numpy_dtype
262
- elif t is bytes:
255
+ if t is bytes:
263
256
  kind = "Bytes"
264
257
  elif t is str:
265
258
  kind = "Str"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cocoindex
3
- Version: 0.1.52
3
+ Version: 0.1.53
4
4
  Requires-Dist: sentence-transformers>=3.3.1
5
5
  Requires-Dist: click>=8.1.8
6
6
  Requires-Dist: rich>=14.0.0
@@ -1,17 +1,17 @@
1
- cocoindex-0.1.52.dist-info/METADATA,sha256=8psE6rLZFYrw7cBEU-rVD1JGf4ezIIHMn-Yk2ZVNEKY,9874
2
- cocoindex-0.1.52.dist-info/WHEEL,sha256=vahFoO0M6DZmYz1cr_3GC_BJXbjcSJJ4jR_JC04LGJc,104
3
- cocoindex-0.1.52.dist-info/entry_points.txt,sha256=_NretjYVzBdNTn7dK-zgwr7YfG2afz1u1uSE-5bZXF8,46
4
- cocoindex-0.1.52.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
1
+ cocoindex-0.1.53.dist-info/METADATA,sha256=HQKpE-cc_Dgi2zpIZOAX8aVCl5Zt8hd__egLbAdlOSU,9874
2
+ cocoindex-0.1.53.dist-info/WHEEL,sha256=vahFoO0M6DZmYz1cr_3GC_BJXbjcSJJ4jR_JC04LGJc,104
3
+ cocoindex-0.1.53.dist-info/entry_points.txt,sha256=_NretjYVzBdNTn7dK-zgwr7YfG2afz1u1uSE-5bZXF8,46
4
+ cocoindex-0.1.53.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
5
5
  cocoindex/__init__.py,sha256=3potlGouaGAznMTEqU5QKuB934qpD1xR6ZH6EL3YMWw,1782
6
- cocoindex/_engine.cpython-312-darwin.so,sha256=MuIjmOSOeRSPbnynxicEE3YS9XNwJpd_noB_Yc4L4Ws,57344192
6
+ cocoindex/_engine.cpython-312-darwin.so,sha256=7BcD6RlbQ1I1gn8ponJxeRsmwQP4W2t8160TurmWR3k,57344864
7
7
  cocoindex/auth_registry.py,sha256=1XqO7ibjmBBd8i11XSJTvTgdz8p1ptW-ZpuSgo_5zzk,716
8
8
  cocoindex/cli.py,sha256=joSucqLibjZHs27fwNPzuMPGLeVJJZqaUbeSM-OnEyA,18223
9
- cocoindex/convert.py,sha256=s5T2IQ1tMBT9JKTzqQtQ-I0sUrfjBGKAnMjDjY6LdrI,8785
9
+ cocoindex/convert.py,sha256=qP_LRTg_B8YPEeHprnX4kl-QV4ks0oqQzH7A1totZRA,9787
10
10
  cocoindex/flow.py,sha256=uQKvIyWSysSdhNFI87QD6f2kbF-2R6zFMLZWBEk4rZU,30015
11
11
  cocoindex/functions.py,sha256=9A61Jj5a3vQoI2MIAhjXvJrDxSzDhe6VncQWbiVtwcg,2393
12
12
  cocoindex/index.py,sha256=j93B9jEvvLXHtpzKWL88SY6wCGEoPgpsQhEGHlyYGFg,540
13
13
  cocoindex/lib.py,sha256=BeRUn3RqE_wSsVtsgCzbFFKe1LXgRyRmMOcmwWBuEXo,2940
14
- cocoindex/llm.py,sha256=KO-R4mrAWtxXD82-Yv5ixpkKMVfkwpbdWwqPVZygLu4,352
14
+ cocoindex/llm.py,sha256=gbjAsD8Y9QZyICCQfSDZYQU2pIA6GqG4jKbRpbnQniA,408
15
15
  cocoindex/op.py,sha256=Z7V9Fdz4qeTeozzKmp1Dk1lPUP0GBgVgD_vBEhTJS5Y,11810
16
16
  cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  cocoindex/runtime.py,sha256=bAdHYaXFWiiUWyAgzmKTeaAaRR0D_AmaqVCIdPO-v00,1056
@@ -20,9 +20,9 @@ cocoindex/setup.py,sha256=u5dYZFKfz4yZLiGHD0guNaR0s4zY9JAoZWrWHpAHw_0,773
20
20
  cocoindex/sources.py,sha256=JCnOhv1w4o28e03i7yvo4ESicWYAhckkBg5bQlxNH4U,1330
21
21
  cocoindex/targets.py,sha256=Nfh_tpFd1goTnS_cxBjIs4j9zl3Z4Z1JomAQ1dl3Sic,2796
22
22
  cocoindex/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
23
- cocoindex/tests/test_convert.py,sha256=mlCJ1U_T7sPDk1Dw78V1BKaYbfqjDjGAaOZT0zMYOwg,28960
23
+ cocoindex/tests/test_convert.py,sha256=iDA4j-X_r1KeJgaa3osU7SyY8rBc57R0YVW8G4Z83YI,33235
24
24
  cocoindex/tests/test_optional_database.py,sha256=snAmkNa6wtOSaxoZE1HgjvL5v_ylitt3Jt_9df4Cgdc,8506
25
- cocoindex/tests/test_typing.py,sha256=4cy_6kPyGxsM6qX-O8K-jXhsgDr_FXePplERgky22qI,12313
26
- cocoindex/typing.py,sha256=PEV_ds7AGHAHkZmX_fWg-2mFiEXDx-V-Rlhv11_8rtk,12387
25
+ cocoindex/tests/test_typing.py,sha256=jeZFUyzUTC6hdyo1uERhlKaKAiBKeKCrZ20OLGBkluo,14184
26
+ cocoindex/typing.py,sha256=yLtXdXEF0sYe3cMxTR__g39Gi286CnOKGI0T-hjaCrw,11999
27
27
  cocoindex/utils.py,sha256=hUhX-XV6XGCtJSEIpBOuDv6VvqImwPlgBxztBTw7u0U,598
28
- cocoindex-0.1.52.dist-info/RECORD,,
28
+ cocoindex-0.1.53.dist-info/RECORD,,