cocoindex 0.1.71__cp313-cp313-macosx_11_0_arm64.whl → 0.1.73__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
cocoindex/convert.py CHANGED
@@ -13,12 +13,19 @@ import numpy as np
13
13
  from .typing import (
14
14
  KEY_FIELD_NAME,
15
15
  TABLE_TYPES,
16
- DtypeRegistry,
17
16
  analyze_type_info,
18
17
  encode_enriched_type,
19
- extract_ndarray_scalar_dtype,
20
18
  is_namedtuple_type,
21
19
  is_struct_type,
20
+ AnalyzedTypeInfo,
21
+ AnalyzedAnyType,
22
+ AnalyzedDictType,
23
+ AnalyzedListType,
24
+ AnalyzedBasicType,
25
+ AnalyzedUnionType,
26
+ AnalyzedUnknownType,
27
+ AnalyzedStructType,
28
+ is_numpy_number_type,
22
29
  )
23
30
 
24
31
 
@@ -79,46 +86,88 @@ def make_engine_value_decoder(
79
86
  Returns:
80
87
  A decoder from an engine value to a Python value.
81
88
  """
89
+
82
90
  src_type_kind = src_type["kind"]
83
91
 
84
- dst_is_any = (
85
- dst_annotation is None
86
- or dst_annotation is inspect.Parameter.empty
87
- or dst_annotation is Any
88
- )
89
- if dst_is_any:
90
- if src_type_kind == "Union":
91
- return lambda value: value[1]
92
- if src_type_kind == "Struct":
93
- return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
94
- if src_type_kind in TABLE_TYPES:
95
- if src_type_kind == "LTable":
92
+ dst_type_info = analyze_type_info(dst_annotation)
93
+ dst_type_variant = dst_type_info.variant
94
+
95
+ if isinstance(dst_type_variant, AnalyzedUnknownType):
96
+ raise ValueError(
97
+ f"Type mismatch for `{''.join(field_path)}`: "
98
+ f"declared `{dst_type_info.core_type}`, an unsupported type"
99
+ )
100
+
101
+ if src_type_kind == "Struct":
102
+ return _make_engine_struct_value_decoder(
103
+ field_path,
104
+ src_type["fields"],
105
+ dst_type_info,
106
+ )
107
+
108
+ if src_type_kind in TABLE_TYPES:
109
+ field_path.append("[*]")
110
+ engine_fields_schema = src_type["row"]["fields"]
111
+
112
+ if src_type_kind == "LTable":
113
+ if isinstance(dst_type_variant, AnalyzedAnyType):
96
114
  return _make_engine_ltable_to_list_dict_decoder(
97
- field_path, src_type["row"]["fields"]
115
+ field_path, engine_fields_schema
116
+ )
117
+ if not isinstance(dst_type_variant, AnalyzedListType):
118
+ raise ValueError(
119
+ f"Type mismatch for `{''.join(field_path)}`: "
120
+ f"declared `{dst_type_info.core_type}`, a list type expected"
98
121
  )
99
- elif src_type_kind == "KTable":
122
+ row_decoder = _make_engine_struct_value_decoder(
123
+ field_path,
124
+ engine_fields_schema,
125
+ analyze_type_info(dst_type_variant.elem_type),
126
+ )
127
+
128
+ def decode(value: Any) -> Any | None:
129
+ if value is None:
130
+ return None
131
+ return [row_decoder(v) for v in value]
132
+
133
+ elif src_type_kind == "KTable":
134
+ if isinstance(dst_type_variant, AnalyzedAnyType):
100
135
  return _make_engine_ktable_to_dict_dict_decoder(
101
- field_path, src_type["row"]["fields"]
136
+ field_path, engine_fields_schema
137
+ )
138
+ if not isinstance(dst_type_variant, AnalyzedDictType):
139
+ raise ValueError(
140
+ f"Type mismatch for `{''.join(field_path)}`: "
141
+ f"declared `{dst_type_info.core_type}`, a dict type expected"
102
142
  )
103
- return lambda value: value
104
143
 
105
- # Handle struct -> dict binding for explicit dict annotations
106
- is_dict_annotation = False
107
- if dst_annotation is dict:
108
- is_dict_annotation = True
109
- elif getattr(dst_annotation, "__origin__", None) is dict:
110
- args = getattr(dst_annotation, "__args__", ())
111
- if args == (str, Any):
112
- is_dict_annotation = True
113
- if is_dict_annotation and src_type_kind == "Struct":
114
- return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
144
+ key_field_schema = engine_fields_schema[0]
145
+ field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
146
+ key_decoder = make_engine_value_decoder(
147
+ field_path, key_field_schema["type"], dst_type_variant.key_type
148
+ )
149
+ field_path.pop()
150
+ value_decoder = _make_engine_struct_value_decoder(
151
+ field_path,
152
+ engine_fields_schema[1:],
153
+ analyze_type_info(dst_type_variant.value_type),
154
+ )
115
155
 
116
- dst_type_info = analyze_type_info(dst_annotation)
156
+ def decode(value: Any) -> Any | None:
157
+ if value is None:
158
+ return None
159
+ return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
160
+
161
+ field_path.pop()
162
+ return decode
117
163
 
118
164
  if src_type_kind == "Union":
165
+ if isinstance(dst_type_variant, AnalyzedAnyType):
166
+ return lambda value: value[1]
167
+
119
168
  dst_type_variants = (
120
- dst_type_info.union_variant_types
121
- if dst_type_info.union_variant_types is not None
169
+ dst_type_variant.variant_types
170
+ if isinstance(dst_type_variant, AnalyzedUnionType)
122
171
  else [dst_annotation]
123
172
  )
124
173
  src_type_variants = src_type["types"]
@@ -142,43 +191,33 @@ def make_engine_value_decoder(
142
191
  decoders.append(decoder)
143
192
  return lambda value: decoders[value[0]](value[1])
144
193
 
145
- if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
146
- raise ValueError(
147
- f"Type mismatch for `{''.join(field_path)}`: "
148
- f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
149
- )
150
-
151
- if dst_type_info.kind in ("Float32", "Float64", "Int64"):
152
- dst_core_type = dst_type_info.core_type
153
-
154
- def decode_scalar(value: Any) -> Any | None:
155
- if value is None:
156
- if dst_type_info.nullable:
157
- return None
158
- raise ValueError(
159
- f"Received null for non-nullable scalar `{''.join(field_path)}`"
160
- )
161
- return dst_core_type(value)
162
-
163
- return decode_scalar
194
+ if isinstance(dst_type_variant, AnalyzedAnyType):
195
+ return lambda value: value
164
196
 
165
197
  if src_type_kind == "Vector":
166
198
  field_path_str = "".join(field_path)
199
+ if not isinstance(dst_type_variant, AnalyzedListType):
200
+ raise ValueError(
201
+ f"Type mismatch for `{''.join(field_path)}`: "
202
+ f"declared `{dst_type_info.core_type}`, a list type expected"
203
+ )
167
204
  expected_dim = (
168
- dst_type_info.vector_info.dim if dst_type_info.vector_info else None
205
+ dst_type_variant.vector_info.dim
206
+ if dst_type_variant and dst_type_variant.vector_info
207
+ else None
169
208
  )
170
209
 
171
- elem_decoder = None
210
+ vec_elem_decoder = None
172
211
  scalar_dtype = None
173
- if dst_type_info.np_number_type is None: # for Non-NDArray vector
174
- elem_decoder = make_engine_value_decoder(
212
+ if dst_type_variant and dst_type_info.base_type is np.ndarray:
213
+ if is_numpy_number_type(dst_type_variant.elem_type):
214
+ scalar_dtype = dst_type_variant.elem_type
215
+ else:
216
+ vec_elem_decoder = make_engine_value_decoder(
175
217
  field_path + ["[*]"],
176
218
  src_type["element_type"],
177
- dst_type_info.elem_type,
219
+ dst_type_variant and dst_type_variant.elem_type,
178
220
  )
179
- else: # for NDArray vector
180
- scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
181
- _ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
182
221
 
183
222
  def decode_vector(value: Any) -> Any | None:
184
223
  if value is None:
@@ -197,54 +236,33 @@ def make_engine_value_decoder(
197
236
  f"expected {expected_dim}, got {len(value)}"
198
237
  )
199
238
 
200
- if elem_decoder is not None: # for Non-NDArray vector
201
- return [elem_decoder(v) for v in value]
239
+ if vec_elem_decoder is not None: # for Non-NDArray vector
240
+ return [vec_elem_decoder(v) for v in value]
202
241
  else: # for NDArray vector
203
242
  return np.array(value, dtype=scalar_dtype)
204
243
 
205
244
  return decode_vector
206
245
 
207
- if dst_type_info.struct_type is not None:
208
- return _make_engine_struct_value_decoder(
209
- field_path, src_type["fields"], dst_type_info.struct_type
210
- )
211
-
212
- if src_type_kind in TABLE_TYPES:
213
- field_path.append("[*]")
214
- elem_type_info = analyze_type_info(dst_type_info.elem_type)
215
- if elem_type_info.struct_type is None:
246
+ if isinstance(dst_type_variant, AnalyzedBasicType):
247
+ if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
216
248
  raise ValueError(
217
249
  f"Type mismatch for `{''.join(field_path)}`: "
218
- f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected"
219
- )
220
- engine_fields_schema = src_type["row"]["fields"]
221
- if elem_type_info.key_type is not None:
222
- key_field_schema = engine_fields_schema[0]
223
- field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
224
- key_decoder = make_engine_value_decoder(
225
- field_path, key_field_schema["type"], elem_type_info.key_type
226
- )
227
- field_path.pop()
228
- value_decoder = _make_engine_struct_value_decoder(
229
- field_path, engine_fields_schema[1:], elem_type_info.struct_type
250
+ f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})"
230
251
  )
231
252
 
232
- def decode(value: Any) -> Any | None:
233
- if value is None:
234
- return None
235
- return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
236
- else:
237
- elem_decoder = _make_engine_struct_value_decoder(
238
- field_path, engine_fields_schema, elem_type_info.struct_type
239
- )
253
+ if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
254
+ dst_core_type = dst_type_info.core_type
240
255
 
241
- def decode(value: Any) -> Any | None:
256
+ def decode_scalar(value: Any) -> Any | None:
242
257
  if value is None:
243
- return None
244
- return [elem_decoder(v) for v in value]
258
+ if dst_type_info.nullable:
259
+ return None
260
+ raise ValueError(
261
+ f"Received null for non-nullable scalar `{''.join(field_path)}`"
262
+ )
263
+ return dst_core_type(value)
245
264
 
246
- field_path.pop()
247
- return decode
265
+ return decode_scalar
248
266
 
249
267
  return lambda value: value
250
268
 
@@ -252,11 +270,36 @@ def make_engine_value_decoder(
252
270
  def _make_engine_struct_value_decoder(
253
271
  field_path: list[str],
254
272
  src_fields: list[dict[str, Any]],
255
- dst_struct_type: type,
273
+ dst_type_info: AnalyzedTypeInfo,
256
274
  ) -> Callable[[list[Any]], Any]:
257
275
  """Make a decoder from an engine field values to a Python value."""
258
276
 
277
+ dst_type_variant = dst_type_info.variant
278
+
279
+ use_dict = False
280
+ if isinstance(dst_type_variant, AnalyzedAnyType):
281
+ use_dict = True
282
+ elif isinstance(dst_type_variant, AnalyzedDictType):
283
+ analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
284
+ analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
285
+ use_dict = (
286
+ isinstance(analyzed_key_type.variant, AnalyzedAnyType)
287
+ or (
288
+ isinstance(analyzed_key_type.variant, AnalyzedBasicType)
289
+ and analyzed_key_type.variant.kind == "Str"
290
+ )
291
+ ) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
292
+ if use_dict:
293
+ return _make_engine_struct_to_dict_decoder(field_path, src_fields)
294
+
295
+ if not isinstance(dst_type_variant, AnalyzedStructType):
296
+ raise ValueError(
297
+ f"Type mismatch for `{''.join(field_path)}`: "
298
+ f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
299
+ )
300
+
259
301
  src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
302
+ dst_struct_type = dst_type_variant.struct_type
260
303
 
261
304
  parameters: Mapping[str, inspect.Parameter]
262
305
  if dataclasses.is_dataclass(dst_struct_type):
cocoindex/flow.py CHANGED
@@ -10,6 +10,13 @@ import functools
10
10
  import inspect
11
11
  import re
12
12
 
13
+ from .validation import (
14
+ validate_flow_name,
15
+ NamingError,
16
+ validate_full_flow_name,
17
+ validate_target_name,
18
+ )
19
+
13
20
  from dataclasses import dataclass
14
21
  from enum import Enum
15
22
  from threading import Lock
@@ -300,6 +307,9 @@ class DataScope:
300
307
  )
301
308
 
302
309
  def __setitem__(self, field_name: str, value: DataSlice[T]) -> None:
310
+ from .validation import validate_field_name
311
+
312
+ validate_field_name(field_name)
303
313
  value._state.attach_to_scope(self._engine_data_scope, field_name)
304
314
 
305
315
  def __enter__(self) -> DataScope:
@@ -367,7 +377,7 @@ class DataCollector:
367
377
 
368
378
  def export(
369
379
  self,
370
- name: str,
380
+ target_name: str,
371
381
  target_spec: op.TargetSpec,
372
382
  /,
373
383
  *,
@@ -381,6 +391,8 @@ class DataCollector:
381
391
 
382
392
  `vector_index` is for backward compatibility only. Please use `vector_indexes` instead.
383
393
  """
394
+
395
+ validate_target_name(target_name)
384
396
  if not isinstance(target_spec, op.TargetSpec):
385
397
  raise ValueError(
386
398
  "export() can only be called on a CocoIndex target storage"
@@ -398,7 +410,7 @@ class DataCollector:
398
410
  vector_indexes=vector_indexes,
399
411
  )
400
412
  self._flow_builder_state.engine_flow_builder.export(
401
- name,
413
+ target_name,
402
414
  _spec_kind(target_spec),
403
415
  dump_engine_object(target_spec),
404
416
  dump_engine_object(index_options),
@@ -660,6 +672,8 @@ class Flow:
660
672
  def __init__(
661
673
  self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow]
662
674
  ):
675
+ validate_flow_name(name)
676
+ validate_full_flow_name(full_name)
663
677
  self._name = name
664
678
  self._full_name = full_name
665
679
  engine_flow = None
@@ -831,11 +845,6 @@ def get_flow_full_name(name: str) -> str:
831
845
 
832
846
 
833
847
  def add_flow_def(name: str, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
834
- """Add a flow definition to the cocoindex library."""
835
- if not all(c.isalnum() or c == "_" for c in name):
836
- raise ValueError(
837
- f"Flow name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed."
838
- )
839
848
  with _flows_lock:
840
849
  if name in _flows:
841
850
  raise KeyError(f"Flow with name {name} already exists")
cocoindex/setting.py CHANGED
@@ -6,6 +6,7 @@ import os
6
6
 
7
7
  from typing import Callable, Self, Any, overload
8
8
  from dataclasses import dataclass
9
+ from .validation import validate_app_namespace_name
9
10
 
10
11
  _app_namespace: str = ""
11
12
 
@@ -27,6 +28,8 @@ def split_app_namespace(full_name: str, delimiter: str) -> tuple[str, str]:
27
28
 
28
29
  def set_app_namespace(app_namespace: str) -> None:
29
30
  """Set the application namespace."""
31
+ if app_namespace:
32
+ validate_app_namespace_name(app_namespace)
30
33
  global _app_namespace # pylint: disable=global-statement
31
34
  _app_namespace = app_namespace
32
35
 
@@ -105,7 +105,9 @@ def validate_full_roundtrip_to(
105
105
  for other_value, other_type in decoded_values:
106
106
  decoder = make_engine_value_decoder([], encoded_output_type, other_type)
107
107
  other_decoded_value = decoder(value_from_engine)
108
- assert eq(other_decoded_value, other_value)
108
+ assert eq(other_decoded_value, other_value), (
109
+ f"Expected {other_value} but got {other_decoded_value} for {other_type}"
110
+ )
109
111
 
110
112
 
111
113
  def validate_full_roundtrip(
@@ -1096,6 +1098,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
1096
1098
  validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
1097
1099
 
1098
1100
 
1101
+ def test_full_roundtrip_vector_of_vector() -> None:
1102
+ """Test full roundtrip for vector of vector."""
1103
+ value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
1104
+ validate_full_roundtrip(
1105
+ value_f32,
1106
+ Vector[Vector[np.float32, Literal[3]], Literal[2]],
1107
+ ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]),
1108
+ ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]),
1109
+ (
1110
+ [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
1111
+ list[Vector[cocoindex.Float32, Literal[3]]],
1112
+ ),
1113
+ (
1114
+ value_f32,
1115
+ np.typing.NDArray[np.typing.NDArray[np.float32]],
1116
+ ),
1117
+ )
1118
+
1119
+
1099
1120
  def test_full_roundtrip_vector_other_types() -> None:
1100
1121
  """Test full roundtrip for Vector with non-numeric basic types."""
1101
1122
  uuid_list = [uuid.uuid4(), uuid.uuid4()]
@@ -1216,7 +1237,7 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
1216
1237
  numpy_float: np.float64
1217
1238
  python_float: float
1218
1239
  string: str
1219
- annotated_int: Annotated[np.int64, TypeKind("int")]
1240
+ annotated_int: Annotated[np.int64, TypeKind("Int64")]
1220
1241
  annotated_float: Float32
1221
1242
 
1222
1243
  instance = MixedStruct(