cocoindex 0.1.72__cp311-cp311-win_amd64.whl → 0.1.74__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
cocoindex/convert.py CHANGED
@@ -2,9 +2,12 @@
2
2
  Utilities to convert between Python and engine values.
3
3
  """
4
4
 
5
+ from __future__ import annotations
6
+
5
7
  import dataclasses
6
8
  import datetime
7
9
  import inspect
10
+ import warnings
8
11
  from enum import Enum
9
12
  from typing import Any, Callable, Mapping, get_origin
10
13
 
@@ -13,15 +16,40 @@ import numpy as np
13
16
  from .typing import (
14
17
  KEY_FIELD_NAME,
15
18
  TABLE_TYPES,
16
- DtypeRegistry,
17
19
  analyze_type_info,
18
20
  encode_enriched_type,
19
- extract_ndarray_scalar_dtype,
20
21
  is_namedtuple_type,
21
22
  is_struct_type,
23
+ AnalyzedTypeInfo,
24
+ AnalyzedAnyType,
25
+ AnalyzedDictType,
26
+ AnalyzedListType,
27
+ AnalyzedBasicType,
28
+ AnalyzedUnionType,
29
+ AnalyzedUnknownType,
30
+ AnalyzedStructType,
31
+ is_numpy_number_type,
22
32
  )
23
33
 
24
34
 
35
+ class ChildFieldPath:
36
+ """Context manager to append a field to field_path on enter and pop it on exit."""
37
+
38
+ _field_path: list[str]
39
+ _field_name: str
40
+
41
+ def __init__(self, field_path: list[str], field_name: str):
42
+ self._field_path: list[str] = field_path
43
+ self._field_name = field_name
44
+
45
+ def __enter__(self) -> ChildFieldPath:
46
+ self._field_path.append(self._field_name)
47
+ return self
48
+
49
+ def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
50
+ self._field_path.pop()
51
+
52
+
25
53
  def encode_engine_value(value: Any) -> Any:
26
54
  """Encode a Python value to an engine value."""
27
55
  if dataclasses.is_dataclass(value):
@@ -66,7 +94,7 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
66
94
  def make_engine_value_decoder(
67
95
  field_path: list[str],
68
96
  src_type: dict[str, Any],
69
- dst_annotation: Any,
97
+ dst_type_info: AnalyzedTypeInfo,
70
98
  ) -> Callable[[Any], Any]:
71
99
  """
72
100
  Make a decoder from an engine value to a Python value.
@@ -79,106 +107,138 @@ def make_engine_value_decoder(
79
107
  Returns:
80
108
  A decoder from an engine value to a Python value.
81
109
  """
110
+
82
111
  src_type_kind = src_type["kind"]
83
112
 
84
- dst_is_any = (
85
- dst_annotation is None
86
- or dst_annotation is inspect.Parameter.empty
87
- or dst_annotation is Any
88
- )
89
- if dst_is_any:
90
- if src_type_kind == "Union":
91
- return lambda value: value[1]
92
- if src_type_kind == "Struct":
93
- return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
94
- if src_type_kind in TABLE_TYPES:
113
+ dst_type_variant = dst_type_info.variant
114
+
115
+ if isinstance(dst_type_variant, AnalyzedUnknownType):
116
+ raise ValueError(
117
+ f"Type mismatch for `{''.join(field_path)}`: "
118
+ f"declared `{dst_type_info.core_type}`, an unsupported type"
119
+ )
120
+
121
+ if src_type_kind == "Struct":
122
+ return make_engine_struct_decoder(
123
+ field_path,
124
+ src_type["fields"],
125
+ dst_type_info,
126
+ )
127
+
128
+ if src_type_kind in TABLE_TYPES:
129
+ with ChildFieldPath(field_path, "[*]"):
130
+ engine_fields_schema = src_type["row"]["fields"]
131
+
95
132
  if src_type_kind == "LTable":
96
- return _make_engine_ltable_to_list_dict_decoder(
97
- field_path, src_type["row"]["fields"]
133
+ if isinstance(dst_type_variant, AnalyzedAnyType):
134
+ return _make_engine_ltable_to_list_dict_decoder(
135
+ field_path, engine_fields_schema
136
+ )
137
+ if not isinstance(dst_type_variant, AnalyzedListType):
138
+ raise ValueError(
139
+ f"Type mismatch for `{''.join(field_path)}`: "
140
+ f"declared `{dst_type_info.core_type}`, a list type expected"
141
+ )
142
+ row_decoder = make_engine_struct_decoder(
143
+ field_path,
144
+ engine_fields_schema,
145
+ analyze_type_info(dst_type_variant.elem_type),
98
146
  )
147
+
148
+ def decode(value: Any) -> Any | None:
149
+ if value is None:
150
+ return None
151
+ return [row_decoder(v) for v in value]
152
+
99
153
  elif src_type_kind == "KTable":
100
- return _make_engine_ktable_to_dict_dict_decoder(
101
- field_path, src_type["row"]["fields"]
154
+ if isinstance(dst_type_variant, AnalyzedAnyType):
155
+ return _make_engine_ktable_to_dict_dict_decoder(
156
+ field_path, engine_fields_schema
157
+ )
158
+ if not isinstance(dst_type_variant, AnalyzedDictType):
159
+ raise ValueError(
160
+ f"Type mismatch for `{''.join(field_path)}`: "
161
+ f"declared `{dst_type_info.core_type}`, a dict type expected"
162
+ )
163
+
164
+ key_field_schema = engine_fields_schema[0]
165
+ field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
166
+ key_decoder = make_engine_value_decoder(
167
+ field_path,
168
+ key_field_schema["type"],
169
+ analyze_type_info(dst_type_variant.key_type),
170
+ )
171
+ field_path.pop()
172
+ value_decoder = make_engine_struct_decoder(
173
+ field_path,
174
+ engine_fields_schema[1:],
175
+ analyze_type_info(dst_type_variant.value_type),
102
176
  )
103
- return lambda value: value
104
177
 
105
- # Handle struct -> dict binding for explicit dict annotations
106
- is_dict_annotation = False
107
- if dst_annotation is dict:
108
- is_dict_annotation = True
109
- elif getattr(dst_annotation, "__origin__", None) is dict:
110
- args = getattr(dst_annotation, "__args__", ())
111
- if args == (str, Any):
112
- is_dict_annotation = True
113
- if is_dict_annotation and src_type_kind == "Struct":
114
- return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
178
+ def decode(value: Any) -> Any | None:
179
+ if value is None:
180
+ return None
181
+ return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
115
182
 
116
- dst_type_info = analyze_type_info(dst_annotation)
183
+ return decode
117
184
 
118
185
  if src_type_kind == "Union":
119
- dst_type_variants = (
120
- dst_type_info.union_variant_types
121
- if dst_type_info.union_variant_types is not None
122
- else [dst_annotation]
186
+ if isinstance(dst_type_variant, AnalyzedAnyType):
187
+ return lambda value: value[1]
188
+
189
+ dst_type_info_variants = (
190
+ [analyze_type_info(t) for t in dst_type_variant.variant_types]
191
+ if isinstance(dst_type_variant, AnalyzedUnionType)
192
+ else [dst_type_info]
123
193
  )
124
194
  src_type_variants = src_type["types"]
125
195
  decoders = []
126
196
  for i, src_type_variant in enumerate(src_type_variants):
127
- src_field_path = field_path + [f"[{i}]"]
128
- decoder = None
129
- for dst_type_variant in dst_type_variants:
130
- try:
131
- decoder = make_engine_value_decoder(
132
- src_field_path, src_type_variant, dst_type_variant
197
+ with ChildFieldPath(field_path, f"[{i}]"):
198
+ decoder = None
199
+ for dst_type_info_variant in dst_type_info_variants:
200
+ try:
201
+ decoder = make_engine_value_decoder(
202
+ field_path, src_type_variant, dst_type_info_variant
203
+ )
204
+ break
205
+ except ValueError:
206
+ pass
207
+ if decoder is None:
208
+ raise ValueError(
209
+ f"Type mismatch for `{''.join(field_path)}`: "
210
+ f"cannot find matched target type for source type variant {src_type_variant}"
133
211
  )
134
- break
135
- except ValueError:
136
- pass
137
- if decoder is None:
138
- raise ValueError(
139
- f"Type mismatch for `{''.join(field_path)}`: "
140
- f"cannot find matched target type for source type variant {src_type_variant}"
141
- )
142
- decoders.append(decoder)
212
+ decoders.append(decoder)
143
213
  return lambda value: decoders[value[0]](value[1])
144
214
 
145
- if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
146
- raise ValueError(
147
- f"Type mismatch for `{''.join(field_path)}`: "
148
- f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
149
- )
150
-
151
- if dst_type_info.kind in ("Float32", "Float64", "Int64"):
152
- dst_core_type = dst_type_info.core_type
153
-
154
- def decode_scalar(value: Any) -> Any | None:
155
- if value is None:
156
- if dst_type_info.nullable:
157
- return None
158
- raise ValueError(
159
- f"Received null for non-nullable scalar `{''.join(field_path)}`"
160
- )
161
- return dst_core_type(value)
162
-
163
- return decode_scalar
215
+ if isinstance(dst_type_variant, AnalyzedAnyType):
216
+ return lambda value: value
164
217
 
165
218
  if src_type_kind == "Vector":
166
219
  field_path_str = "".join(field_path)
220
+ if not isinstance(dst_type_variant, AnalyzedListType):
221
+ raise ValueError(
222
+ f"Type mismatch for `{''.join(field_path)}`: "
223
+ f"declared `{dst_type_info.core_type}`, a list type expected"
224
+ )
167
225
  expected_dim = (
168
- dst_type_info.vector_info.dim if dst_type_info.vector_info else None
226
+ dst_type_variant.vector_info.dim
227
+ if dst_type_variant and dst_type_variant.vector_info
228
+ else None
169
229
  )
170
230
 
171
- elem_decoder = None
231
+ vec_elem_decoder = None
172
232
  scalar_dtype = None
173
- if dst_type_info.np_number_type is None: # for Non-NDArray vector
174
- elem_decoder = make_engine_value_decoder(
233
+ if dst_type_variant and dst_type_info.base_type is np.ndarray:
234
+ if is_numpy_number_type(dst_type_variant.elem_type):
235
+ scalar_dtype = dst_type_variant.elem_type
236
+ else:
237
+ vec_elem_decoder = make_engine_value_decoder(
175
238
  field_path + ["[*]"],
176
239
  src_type["element_type"],
177
- dst_type_info.elem_type,
240
+ analyze_type_info(dst_type_variant and dst_type_variant.elem_type),
178
241
  )
179
- else: # for NDArray vector
180
- scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
181
- _ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
182
242
 
183
243
  def decode_vector(value: Any) -> Any | None:
184
244
  if value is None:
@@ -197,66 +257,94 @@ def make_engine_value_decoder(
197
257
  f"expected {expected_dim}, got {len(value)}"
198
258
  )
199
259
 
200
- if elem_decoder is not None: # for Non-NDArray vector
201
- return [elem_decoder(v) for v in value]
260
+ if vec_elem_decoder is not None: # for Non-NDArray vector
261
+ return [vec_elem_decoder(v) for v in value]
202
262
  else: # for NDArray vector
203
263
  return np.array(value, dtype=scalar_dtype)
204
264
 
205
265
  return decode_vector
206
266
 
207
- if dst_type_info.struct_type is not None:
208
- return _make_engine_struct_value_decoder(
209
- field_path, src_type["fields"], dst_type_info.struct_type
210
- )
211
-
212
- if src_type_kind in TABLE_TYPES:
213
- field_path.append("[*]")
214
- elem_type_info = analyze_type_info(dst_type_info.elem_type)
215
- if elem_type_info.struct_type is None:
267
+ if isinstance(dst_type_variant, AnalyzedBasicType):
268
+ if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
216
269
  raise ValueError(
217
270
  f"Type mismatch for `{''.join(field_path)}`: "
218
- f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected"
219
- )
220
- engine_fields_schema = src_type["row"]["fields"]
221
- if elem_type_info.key_type is not None:
222
- key_field_schema = engine_fields_schema[0]
223
- field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
224
- key_decoder = make_engine_value_decoder(
225
- field_path, key_field_schema["type"], elem_type_info.key_type
226
- )
227
- field_path.pop()
228
- value_decoder = _make_engine_struct_value_decoder(
229
- field_path, engine_fields_schema[1:], elem_type_info.struct_type
271
+ f"passed in {src_type_kind}, declared {dst_type_info.core_type} ({dst_type_variant.kind})"
230
272
  )
231
273
 
232
- def decode(value: Any) -> Any | None:
233
- if value is None:
234
- return None
235
- return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
236
- else:
237
- elem_decoder = _make_engine_struct_value_decoder(
238
- field_path, engine_fields_schema, elem_type_info.struct_type
239
- )
274
+ if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
275
+ dst_core_type = dst_type_info.core_type
240
276
 
241
- def decode(value: Any) -> Any | None:
277
+ def decode_scalar(value: Any) -> Any | None:
242
278
  if value is None:
243
- return None
244
- return [elem_decoder(v) for v in value]
279
+ if dst_type_info.nullable:
280
+ return None
281
+ raise ValueError(
282
+ f"Received null for non-nullable scalar `{''.join(field_path)}`"
283
+ )
284
+ return dst_core_type(value)
245
285
 
246
- field_path.pop()
247
- return decode
286
+ return decode_scalar
248
287
 
249
288
  return lambda value: value
250
289
 
251
290
 
252
- def _make_engine_struct_value_decoder(
291
+ def _get_auto_default_for_type(
292
+ type_info: AnalyzedTypeInfo,
293
+ ) -> tuple[Any, bool]:
294
+ """
295
+ Get an auto-default value for a type annotation if it's safe to do so.
296
+
297
+ Returns:
298
+ A tuple of (default_value, is_supported) where:
299
+ - default_value: The default value if auto-defaulting is supported
300
+ - is_supported: True if auto-defaulting is supported for this type
301
+ """
302
+ # Case 1: Nullable types (Optional[T] or T | None)
303
+ if type_info.nullable:
304
+ return None, True
305
+
306
+ # Case 2: Table types (KTable or LTable) - check if it's a list or dict type
307
+ if isinstance(type_info.variant, AnalyzedListType):
308
+ return [], True
309
+ elif isinstance(type_info.variant, AnalyzedDictType):
310
+ return {}, True
311
+
312
+ return None, False
313
+
314
+
315
+ def make_engine_struct_decoder(
253
316
  field_path: list[str],
254
317
  src_fields: list[dict[str, Any]],
255
- dst_struct_type: type,
318
+ dst_type_info: AnalyzedTypeInfo,
256
319
  ) -> Callable[[list[Any]], Any]:
257
320
  """Make a decoder from an engine field values to a Python value."""
258
321
 
322
+ dst_type_variant = dst_type_info.variant
323
+
324
+ use_dict = False
325
+ if isinstance(dst_type_variant, AnalyzedAnyType):
326
+ use_dict = True
327
+ elif isinstance(dst_type_variant, AnalyzedDictType):
328
+ analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
329
+ analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
330
+ use_dict = (
331
+ isinstance(analyzed_key_type.variant, AnalyzedAnyType)
332
+ or (
333
+ isinstance(analyzed_key_type.variant, AnalyzedBasicType)
334
+ and analyzed_key_type.variant.kind == "Str"
335
+ )
336
+ ) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
337
+ if use_dict:
338
+ return _make_engine_struct_to_dict_decoder(field_path, src_fields)
339
+
340
+ if not isinstance(dst_type_variant, AnalyzedStructType):
341
+ raise ValueError(
342
+ f"Type mismatch for `{''.join(field_path)}`: "
343
+ f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
344
+ )
345
+
259
346
  src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
347
+ dst_struct_type = dst_type_variant.struct_type
260
348
 
261
349
  parameters: Mapping[str, inspect.Parameter]
262
350
  if dataclasses.is_dataclass(dst_struct_type):
@@ -278,32 +366,39 @@ def _make_engine_struct_value_decoder(
278
366
  else:
279
367
  raise ValueError(f"Unsupported struct type: {dst_struct_type}")
280
368
 
281
- def make_closure_for_value(
369
+ def make_closure_for_field(
282
370
  name: str, param: inspect.Parameter
283
371
  ) -> Callable[[list[Any]], Any]:
284
372
  src_idx = src_name_to_idx.get(name)
285
- if src_idx is not None:
286
- field_path.append(f".{name}")
287
- field_decoder = make_engine_value_decoder(
288
- field_path, src_fields[src_idx]["type"], param.annotation
289
- )
290
- field_path.pop()
291
- return (
292
- lambda values: field_decoder(values[src_idx])
293
- if len(values) > src_idx
294
- else param.default
295
- )
373
+ type_info = analyze_type_info(param.annotation)
374
+
375
+ with ChildFieldPath(field_path, f".{name}"):
376
+ if src_idx is not None:
377
+ field_decoder = make_engine_value_decoder(
378
+ field_path, src_fields[src_idx]["type"], type_info
379
+ )
380
+ return lambda values: field_decoder(values[src_idx])
381
+
382
+ default_value = param.default
383
+ if default_value is not inspect.Parameter.empty:
384
+ return lambda _: default_value
385
+
386
+ auto_default, is_supported = _get_auto_default_for_type(type_info)
387
+ if is_supported:
388
+ warnings.warn(
389
+ f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
390
+ f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
391
+ UserWarning,
392
+ stacklevel=4,
393
+ )
394
+ return lambda _: auto_default
296
395
 
297
- default_value = param.default
298
- if default_value is inspect.Parameter.empty:
299
396
  raise ValueError(
300
- f"Field without default value is missing in input: {''.join(field_path)}"
397
+ f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
301
398
  )
302
399
 
303
- return lambda _: default_value
304
-
305
400
  field_value_decoder = [
306
- make_closure_for_value(name, param) for (name, param) in parameters.items()
401
+ make_closure_for_field(name, param) for (name, param) in parameters.items()
307
402
  ]
308
403
 
309
404
  return lambda values: dst_struct_type(
@@ -320,13 +415,12 @@ def _make_engine_struct_to_dict_decoder(
320
415
  field_decoders = []
321
416
  for i, field_schema in enumerate(src_fields):
322
417
  field_name = field_schema["name"]
323
- field_path.append(f".{field_name}")
324
- field_decoder = make_engine_value_decoder(
325
- field_path,
326
- field_schema["type"],
327
- Any, # Use Any for recursive decoding
328
- )
329
- field_path.pop()
418
+ with ChildFieldPath(field_path, f".{field_name}"):
419
+ field_decoder = make_engine_value_decoder(
420
+ field_path,
421
+ field_schema["type"],
422
+ analyze_type_info(Any), # Use Any for recursive decoding
423
+ )
330
424
  field_decoders.append((field_name, field_decoder))
331
425
 
332
426
  def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
@@ -383,9 +477,10 @@ def _make_engine_ktable_to_dict_dict_decoder(
383
477
  value_fields_schema = src_fields[1:]
384
478
 
385
479
  # Create decoders
386
- field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
387
- key_decoder = make_engine_value_decoder(field_path, key_field_schema["type"], Any)
388
- field_path.pop()
480
+ with ChildFieldPath(field_path, f".{key_field_schema.get('name', KEY_FIELD_NAME)}"):
481
+ key_decoder = make_engine_value_decoder(
482
+ field_path, key_field_schema["type"], analyze_type_info(Any)
483
+ )
389
484
 
390
485
  value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)
391
486
 
cocoindex/flow.py CHANGED
@@ -16,6 +16,7 @@ from .validation import (
16
16
  validate_full_flow_name,
17
17
  validate_target_name,
18
18
  )
19
+ from .typing import analyze_type_info
19
20
 
20
21
  from dataclasses import dataclass
21
22
  from enum import Enum
@@ -1053,7 +1054,7 @@ class TransformFlow(Generic[T]):
1053
1054
  sig.return_annotation
1054
1055
  )
1055
1056
  result_decoder = make_engine_value_decoder(
1056
- [], engine_return_type["type"], python_return_type
1057
+ [], engine_return_type["type"], analyze_type_info(python_return_type)
1057
1058
  )
1058
1059
 
1059
1060
  return TransformFlowInfo(engine_flow, result_decoder)