cocoindex 0.1.43__cp311-cp311-macosx_11_0_arm64.whl → 0.1.45__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/flow.py CHANGED
@@ -10,14 +10,24 @@ import inspect
10
10
  import datetime
11
11
  import functools
12
12
 
13
- from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple
13
+ from typing import (
14
+ Any,
15
+ Callable,
16
+ Sequence,
17
+ TypeVar,
18
+ Generic,
19
+ get_args,
20
+ get_origin,
21
+ NamedTuple,
22
+ cast,
23
+ )
14
24
  from threading import Lock
15
25
  from enum import Enum
16
26
  from dataclasses import dataclass
17
27
  from rich.text import Text
18
28
  from rich.tree import Tree
19
29
 
20
- from . import _engine
30
+ from . import _engine # type: ignore
21
31
  from . import index
22
32
  from . import op
23
33
  from . import setting
@@ -25,11 +35,12 @@ from .convert import dump_engine_object, encode_engine_value, make_engine_value_
25
35
  from .typing import encode_enriched_type
26
36
  from .runtime import execution_context
27
37
 
38
+
28
39
  class _NameBuilder:
29
40
  _existing_names: set[str]
30
41
  _next_name_index: dict[str, int]
31
42
 
32
- def __init__(self):
43
+ def __init__(self) -> None:
33
44
  self._existing_names = set()
34
45
  self._next_name_index = {}
35
46
 
@@ -51,40 +62,53 @@ class _NameBuilder:
51
62
  return name
52
63
 
53
64
 
54
- _WORD_BOUNDARY_RE = re.compile('(?<!^)(?=[A-Z])')
65
+ _WORD_BOUNDARY_RE = re.compile("(?<!^)(?=[A-Z])")
66
+
67
+
55
68
  def _to_snake_case(name: str) -> str:
56
- return _WORD_BOUNDARY_RE.sub('_', name).lower()
69
+ return _WORD_BOUNDARY_RE.sub("_", name).lower()
70
+
57
71
 
58
72
  def _create_data_slice(
59
- flow_builder_state: _FlowBuilderState,
60
- creator: Callable[[_engine.DataScopeRef | None, str | None], _engine.DataSlice],
61
- name: str | None = None) -> DataSlice:
73
+ flow_builder_state: _FlowBuilderState,
74
+ creator: Callable[[_engine.DataScopeRef | None, str | None], _engine.DataSlice],
75
+ name: str | None = None,
76
+ ) -> DataSlice[T]:
62
77
  if name is None:
63
- return DataSlice(_DataSliceState(
64
- flow_builder_state,
65
- lambda target:
66
- creator(target[0], target[1]) if target is not None else creator(None, None)))
78
+ return DataSlice(
79
+ _DataSliceState(
80
+ flow_builder_state,
81
+ lambda target: creator(target[0], target[1])
82
+ if target is not None
83
+ else creator(None, None),
84
+ )
85
+ )
67
86
  else:
68
87
  return DataSlice(_DataSliceState(flow_builder_state, creator(None, name)))
69
88
 
70
89
 
71
90
  def _spec_kind(spec: Any) -> str:
72
- return spec.__class__.__name__
91
+ return cast(str, spec.__class__.__name__)
92
+
93
+
94
+ T = TypeVar("T")
73
95
 
74
- T = TypeVar('T')
75
96
 
76
97
  class _DataSliceState:
77
98
  flow_builder_state: _FlowBuilderState
78
99
 
79
100
  _lazy_lock: Lock | None = None # None means it's not lazy.
80
101
  _data_slice: _engine.DataSlice | None = None
81
- _data_slice_creator: Callable[[tuple[_engine.DataScopeRef, str] | None],
82
- _engine.DataSlice] | None = None
102
+ _data_slice_creator: (
103
+ Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice] | None
104
+ ) = None
83
105
 
84
106
  def __init__(
85
- self, flow_builder_state: _FlowBuilderState,
86
- data_slice: _engine.DataSlice | Callable[[tuple[_engine.DataScopeRef, str] | None],
87
- _engine.DataSlice]):
107
+ self,
108
+ flow_builder_state: _FlowBuilderState,
109
+ data_slice: _engine.DataSlice
110
+ | Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice],
111
+ ):
88
112
  self.flow_builder_state = flow_builder_state
89
113
 
90
114
  if isinstance(data_slice, _engine.DataSlice):
@@ -124,6 +148,7 @@ class _DataSliceState:
124
148
  # TODO: We'll support this by an identity transformer or "aliasing" in the future.
125
149
  raise ValueError("DataSlice is already attached to a field")
126
150
 
151
+
127
152
  class DataSlice(Generic[T]):
128
153
  """A data slice represents a slice of data in a flow. It's readonly."""
129
154
 
@@ -132,13 +157,13 @@ class DataSlice(Generic[T]):
132
157
  def __init__(self, state: _DataSliceState):
133
158
  self._state = state
134
159
 
135
- def __str__(self):
160
+ def __str__(self) -> str:
136
161
  return str(self._state.engine_data_slice)
137
162
 
138
- def __repr__(self):
163
+ def __repr__(self) -> str:
139
164
  return repr(self._state.engine_data_slice)
140
165
 
141
- def __getitem__(self, field_name: str) -> DataSlice:
166
+ def __getitem__(self, field_name: str) -> DataSlice[T]:
142
167
  field_slice = self._state.engine_data_slice.field(field_name)
143
168
  if field_slice is None:
144
169
  raise KeyError(field_name)
@@ -158,7 +183,9 @@ class DataSlice(Generic[T]):
158
183
  with self.row() as scope:
159
184
  f(scope)
160
185
 
161
- def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
186
+ def transform(
187
+ self, fn_spec: op.FunctionSpec, *args: Any, **kwargs: Any
188
+ ) -> DataSlice[T]:
162
189
  """
163
190
  Apply a function to the data slice.
164
191
  """
@@ -167,63 +194,77 @@ class DataSlice(Generic[T]):
167
194
 
168
195
  transform_args: list[tuple[Any, str | None]]
169
196
  transform_args = [(self._state.engine_data_slice, None)]
170
- transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args]
171
- transform_args += [(self._state.flow_builder_state.get_data_slice(v), k)
172
- for (k, v) in kwargs.items()]
197
+ transform_args += [
198
+ (self._state.flow_builder_state.get_data_slice(v), None) for v in args
199
+ ]
200
+ transform_args += [
201
+ (self._state.flow_builder_state.get_data_slice(v), k)
202
+ for (k, v) in kwargs.items()
203
+ ]
173
204
 
174
205
  flow_builder_state = self._state.flow_builder_state
175
206
  return _create_data_slice(
176
207
  flow_builder_state,
177
- lambda target_scope, name:
178
- flow_builder_state.engine_flow_builder.transform(
179
- _spec_kind(fn_spec),
180
- dump_engine_object(fn_spec),
181
- transform_args,
182
- target_scope,
183
- flow_builder_state.field_name_builder.build_name(
184
- name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
185
- ))
186
-
187
- def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T:
208
+ lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
209
+ _spec_kind(fn_spec),
210
+ dump_engine_object(fn_spec),
211
+ transform_args,
212
+ target_scope,
213
+ flow_builder_state.field_name_builder.build_name(
214
+ name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
215
+ ),
216
+ ),
217
+ )
218
+
219
+ def call(self, func: Callable[[DataSlice[T]], T], *args: Any, **kwargs: Any) -> T:
188
220
  """
189
221
  Call a function with the data slice.
190
222
  """
191
223
  return func(self, *args, **kwargs)
192
224
 
193
- def _data_slice_state(data_slice: DataSlice) -> _DataSliceState:
225
+
226
+ def _data_slice_state(data_slice: DataSlice[T]) -> _DataSliceState:
194
227
  return data_slice._state # pylint: disable=protected-access
195
228
 
229
+
196
230
  class DataScope:
197
231
  """
198
232
  A data scope in a flow.
199
233
  It has multple fields and collectors, and allow users to add new fields and collectors.
200
234
  """
235
+
201
236
  _flow_builder_state: _FlowBuilderState
202
237
  _engine_data_scope: _engine.DataScopeRef
203
238
 
204
- def __init__(self, flow_builder_state: _FlowBuilderState, data_scope: _engine.DataScopeRef):
239
+ def __init__(
240
+ self, flow_builder_state: _FlowBuilderState, data_scope: _engine.DataScopeRef
241
+ ):
205
242
  self._flow_builder_state = flow_builder_state
206
243
  self._engine_data_scope = data_scope
207
244
 
208
- def __str__(self):
245
+ def __str__(self) -> str:
209
246
  return str(self._engine_data_scope)
210
247
 
211
- def __repr__(self):
248
+ def __repr__(self) -> str:
212
249
  return repr(self._engine_data_scope)
213
250
 
214
- def __getitem__(self, field_name: str) -> DataSlice:
215
- return DataSlice(_DataSliceState(
216
- self._flow_builder_state,
217
- self._flow_builder_state.engine_flow_builder.scope_field(
218
- self._engine_data_scope, field_name)))
251
+ def __getitem__(self, field_name: str) -> DataSlice[T]:
252
+ return DataSlice(
253
+ _DataSliceState(
254
+ self._flow_builder_state,
255
+ self._flow_builder_state.engine_flow_builder.scope_field(
256
+ self._engine_data_scope, field_name
257
+ ),
258
+ )
259
+ )
219
260
 
220
- def __setitem__(self, field_name: str, value: DataSlice):
261
+ def __setitem__(self, field_name: str, value: DataSlice[T]) -> None:
221
262
  value._state.attach_to_scope(self._engine_data_scope, field_name)
222
263
 
223
- def __enter__(self):
264
+ def __enter__(self) -> DataScope:
224
265
  return self
225
266
 
226
- def __exit__(self, exc_type, exc_value, traceback):
267
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
227
268
  del self._engine_data_scope
228
269
 
229
270
  def add_collector(self, name: str | None = None) -> DataCollector:
@@ -233,27 +274,36 @@ class DataScope:
233
274
  return DataCollector(
234
275
  self._flow_builder_state,
235
276
  self._engine_data_scope.add_collector(
236
- self._flow_builder_state.field_name_builder.build_name(name, prefix="_collector_")
237
- )
277
+ self._flow_builder_state.field_name_builder.build_name(
278
+ name, prefix="_collector_"
279
+ )
280
+ ),
238
281
  )
239
282
 
283
+
240
284
  class GeneratedField(Enum):
241
285
  """
242
286
  A generated field is automatically set by the engine.
243
287
  """
288
+
244
289
  UUID = "Uuid"
245
290
 
291
+
246
292
  class DataCollector:
247
293
  """A data collector is used to collect data into a collector."""
294
+
248
295
  _flow_builder_state: _FlowBuilderState
249
296
  _engine_data_collector: _engine.DataCollector
250
297
 
251
- def __init__(self, flow_builder_state: _FlowBuilderState,
252
- data_collector: _engine.DataCollector):
298
+ def __init__(
299
+ self,
300
+ flow_builder_state: _FlowBuilderState,
301
+ data_collector: _engine.DataCollector,
302
+ ):
253
303
  self._flow_builder_state = flow_builder_state
254
304
  self._engine_data_collector = data_collector
255
305
 
256
- def collect(self, **kwargs):
306
+ def collect(self, **kwargs: Any) -> None:
257
307
  """
258
308
  Collect data into the collector.
259
309
  """
@@ -268,45 +318,62 @@ class DataCollector:
268
318
  else:
269
319
  raise ValueError(f"Unexpected generated field: {v}")
270
320
  else:
271
- regular_kwargs.append(
272
- (k, self._flow_builder_state.get_data_slice(v)))
321
+ regular_kwargs.append((k, self._flow_builder_state.get_data_slice(v)))
273
322
 
274
323
  self._flow_builder_state.engine_flow_builder.collect(
275
- self._engine_data_collector, regular_kwargs, auto_uuid_field)
324
+ self._engine_data_collector, regular_kwargs, auto_uuid_field
325
+ )
276
326
 
277
- def export(self, name: str, target_spec: op.StorageSpec, /, *,
278
- primary_key_fields: Sequence[str],
279
- vector_indexes: Sequence[index.VectorIndexDef] = (),
280
- vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (),
281
- setup_by_user: bool = False):
327
+ def export(
328
+ self,
329
+ name: str,
330
+ target_spec: op.StorageSpec,
331
+ /,
332
+ *,
333
+ primary_key_fields: Sequence[str],
334
+ vector_indexes: Sequence[index.VectorIndexDef] = (),
335
+ vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (),
336
+ setup_by_user: bool = False,
337
+ ) -> None:
282
338
  """
283
339
  Export the collected data to the specified target.
284
340
 
285
341
  `vector_index` is for backward compatibility only. Please use `vector_indexes` instead.
286
342
  """
287
343
  if not isinstance(target_spec, op.StorageSpec):
288
- raise ValueError("export() can only be called on a CocoIndex target storage")
344
+ raise ValueError(
345
+ "export() can only be called on a CocoIndex target storage"
346
+ )
289
347
 
290
348
  # For backward compatibility only.
291
349
  if len(vector_indexes) == 0 and len(vector_index) > 0:
292
- vector_indexes = [index.VectorIndexDef(field_name=field_name, metric=metric)
293
- for field_name, metric in vector_index]
350
+ vector_indexes = [
351
+ index.VectorIndexDef(field_name=field_name, metric=metric)
352
+ for field_name, metric in vector_index
353
+ ]
294
354
 
295
355
  index_options = index.IndexOptions(
296
356
  primary_key_fields=primary_key_fields,
297
357
  vector_indexes=vector_indexes,
298
358
  )
299
359
  self._flow_builder_state.engine_flow_builder.export(
300
- name, _spec_kind(target_spec), dump_engine_object(target_spec),
301
- dump_engine_object(index_options), self._engine_data_collector, setup_by_user)
360
+ name,
361
+ _spec_kind(target_spec),
362
+ dump_engine_object(target_spec),
363
+ dump_engine_object(index_options),
364
+ self._engine_data_collector,
365
+ setup_by_user,
366
+ )
302
367
 
303
368
 
304
369
  _flow_name_builder = _NameBuilder()
305
370
 
371
+
306
372
  class _FlowBuilderState:
307
373
  """
308
374
  A flow builder is used to build a flow.
309
375
  """
376
+
310
377
  engine_flow_builder: _engine.FlowBuilder
311
378
  field_name_builder: _NameBuilder
312
379
 
@@ -322,32 +389,40 @@ class _FlowBuilderState:
322
389
  return v._state.engine_data_slice
323
390
  return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v)
324
391
 
392
+
325
393
  @dataclass
326
394
  class _SourceRefreshOptions:
327
395
  """
328
396
  Options for refreshing a source.
329
397
  """
398
+
330
399
  refresh_interval: datetime.timedelta | None = None
331
400
 
401
+
332
402
  class FlowBuilder:
333
403
  """
334
404
  A flow builder is used to build a flow.
335
405
  """
406
+
336
407
  _state: _FlowBuilderState
337
408
 
338
409
  def __init__(self, state: _FlowBuilderState):
339
410
  self._state = state
340
411
 
341
- def __str__(self):
412
+ def __str__(self) -> str:
342
413
  return str(self._state.engine_flow_builder)
343
414
 
344
- def __repr__(self):
415
+ def __repr__(self) -> str:
345
416
  return repr(self._state.engine_flow_builder)
346
417
 
347
- def add_source(self, spec: op.SourceSpec, /, *,
348
- name: str | None = None,
349
- refresh_interval: datetime.timedelta | None = None,
350
- ) -> DataSlice:
418
+ def add_source(
419
+ self,
420
+ spec: op.SourceSpec,
421
+ /,
422
+ *,
423
+ name: str | None = None,
424
+ refresh_interval: datetime.timedelta | None = None,
425
+ ) -> DataSlice[T]:
351
426
  """
352
427
  Import a source to the flow.
353
428
  """
@@ -360,30 +435,37 @@ class FlowBuilder:
360
435
  dump_engine_object(spec),
361
436
  target_scope,
362
437
  self._state.field_name_builder.build_name(
363
- name, prefix=_to_snake_case(_spec_kind(spec))+'_'),
364
- dump_engine_object(_SourceRefreshOptions(refresh_interval=refresh_interval)),
438
+ name, prefix=_to_snake_case(_spec_kind(spec)) + "_"
439
+ ),
440
+ dump_engine_object(
441
+ _SourceRefreshOptions(refresh_interval=refresh_interval)
442
+ ),
365
443
  ),
366
- name
444
+ name,
367
445
  )
368
446
 
369
- def declare(self, spec: op.DeclarationSpec):
447
+ def declare(self, spec: op.DeclarationSpec) -> None:
370
448
  """
371
449
  Add a declaration to the flow.
372
450
  """
373
451
  self._state.engine_flow_builder.declare(dump_engine_object(spec))
374
452
 
453
+
375
454
  @dataclass
376
455
  class FlowLiveUpdaterOptions:
377
456
  """
378
457
  Options for live updating a flow.
379
458
  """
459
+
380
460
  live_mode: bool = True
381
461
  print_stats: bool = False
382
462
 
463
+
383
464
  class FlowLiveUpdater:
384
465
  """
385
466
  A live updater for a flow.
386
467
  """
468
+
387
469
  _flow: Flow
388
470
  _options: FlowLiveUpdaterOptions
389
471
  _engine_live_updater: _engine.FlowLiveUpdater | None = None
@@ -396,7 +478,7 @@ class FlowLiveUpdater:
396
478
  self.start()
397
479
  return self
398
480
 
399
- def __exit__(self, exc_type, exc_value, traceback):
481
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
400
482
  self.abort()
401
483
  self.wait()
402
484
 
@@ -404,7 +486,7 @@ class FlowLiveUpdater:
404
486
  await self.start_async()
405
487
  return self
406
488
 
407
- async def __aexit__(self, exc_type, exc_value, traceback):
489
+ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
408
490
  self.abort()
409
491
  await self.wait_async()
410
492
 
@@ -419,7 +501,8 @@ class FlowLiveUpdater:
419
501
  Start the live updater.
420
502
  """
421
503
  self._engine_live_updater = await _engine.FlowLiveUpdater.create(
422
- await self._flow.internal_flow_async(), dump_engine_object(self._options))
504
+ await self._flow.internal_flow_async(), dump_engine_object(self._options)
505
+ )
423
506
 
424
507
  def wait(self) -> None:
425
508
  """
@@ -456,22 +539,28 @@ class EvaluateAndDumpOptions:
456
539
  """
457
540
  Options for evaluating and dumping a flow.
458
541
  """
542
+
459
543
  output_dir: str
460
544
  use_cache: bool = True
461
545
 
546
+
462
547
  class Flow:
463
548
  """
464
549
  A flow describes an indexing pipeline.
465
550
  """
551
+
466
552
  _name: str
467
553
  _full_name: str
468
554
  _lazy_engine_flow: Callable[[], _engine.Flow]
469
555
 
470
- def __init__(self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow]):
556
+ def __init__(
557
+ self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow]
558
+ ):
471
559
  self._name = name
472
560
  self._full_name = full_name
473
561
  engine_flow = None
474
562
  lock = Lock()
563
+
475
564
  def _lazy_engine_flow() -> _engine.Flow:
476
565
  nonlocal engine_flow, lock
477
566
  if engine_flow is None:
@@ -479,6 +568,7 @@ class Flow:
479
568
  if engine_flow is None:
480
569
  engine_flow = engine_flow_creator()
481
570
  return engine_flow
571
+
482
572
  self._lazy_engine_flow = _lazy_engine_flow
483
573
 
484
574
  def _render_spec(self, verbose: bool = False) -> Tree:
@@ -488,7 +578,7 @@ class Flow:
488
578
  spec = self._get_spec(verbose=verbose)
489
579
  tree = Tree(f"Flow: {self.full_name}", style="cyan")
490
580
 
491
- def build_tree(label: str, lines: list):
581
+ def build_tree(label: str, lines: list[Any]) -> Tree:
492
582
  node = Tree(label=label if lines else label + " None", style="cyan")
493
583
  for line in lines:
494
584
  child_node = node.add(Text(line.content, style="yellow"))
@@ -501,15 +591,17 @@ class Flow:
501
591
  return tree
502
592
 
503
593
  def _get_spec(self, verbose: bool = False) -> _engine.RenderedSpec:
504
- return self._lazy_engine_flow().get_spec(output_mode="verbose" if verbose else "concise")
505
-
594
+ return self._lazy_engine_flow().get_spec(
595
+ output_mode="verbose" if verbose else "concise"
596
+ )
597
+
506
598
  def _get_schema(self) -> list[tuple[str, str, str]]:
507
- return self._lazy_engine_flow().get_schema()
599
+ return cast(list[tuple[str, str, str]], self._lazy_engine_flow().get_schema())
508
600
 
509
- def __str__(self):
601
+ def __str__(self) -> str:
510
602
  return str(self._get_spec())
511
603
 
512
- def __repr__(self):
604
+ def __repr__(self) -> str:
513
605
  return repr(self._lazy_engine_flow())
514
606
 
515
607
  @property
@@ -538,11 +630,15 @@ class Flow:
538
630
  Update the index defined by the flow.
539
631
  Once the function returns, the index is fresh up to the moment when the function is called.
540
632
  """
541
- updater = await FlowLiveUpdater.create_async(self, FlowLiveUpdaterOptions(live_mode=False))
542
- await updater.wait_async()
633
+ async with FlowLiveUpdater(
634
+ self, FlowLiveUpdaterOptions(live_mode=False)
635
+ ) as updater:
636
+ await updater.wait_async()
543
637
  return updater.update_stats()
544
638
 
545
- def evaluate_and_dump(self, options: EvaluateAndDumpOptions):
639
+ def evaluate_and_dump(
640
+ self, options: EvaluateAndDumpOptions
641
+ ) -> _engine.IndexUpdateInfo:
546
642
  """
547
643
  Evaluate the flow and dump flow outputs to files.
548
644
  """
@@ -560,19 +656,26 @@ class Flow:
560
656
  """
561
657
  return await asyncio.to_thread(self.internal_flow)
562
658
 
563
- def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
659
+
660
+ def _create_lazy_flow(
661
+ name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]
662
+ ) -> Flow:
564
663
  """
565
664
  Create a flow without really building it yet.
566
665
  The flow will be built the first time when it's really needed.
567
666
  """
568
667
  flow_name = _flow_name_builder.build_name(name, prefix="_flow_")
569
668
  flow_full_name = get_full_flow_name(flow_name)
669
+
570
670
  def _create_engine_flow() -> _engine.Flow:
571
671
  flow_builder_state = _FlowBuilderState(flow_full_name)
572
672
  root_scope = DataScope(
573
- flow_builder_state, flow_builder_state.engine_flow_builder.root_scope())
673
+ flow_builder_state, flow_builder_state.engine_flow_builder.root_scope()
674
+ )
574
675
  fl_def(FlowBuilder(flow_builder_state), root_scope)
575
- return flow_builder_state.engine_flow_builder.build_flow(execution_context.event_loop)
676
+ return flow_builder_state.engine_flow_builder.build_flow(
677
+ execution_context.event_loop
678
+ )
576
679
 
577
680
  return Flow(flow_name, flow_full_name, _create_engine_flow)
578
681
 
@@ -580,28 +683,36 @@ def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope
580
683
  _flows_lock = Lock()
581
684
  _flows: dict[str, Flow] = {}
582
685
 
686
+
583
687
  def get_full_flow_name(name: str) -> str:
584
688
  """
585
689
  Get the full name of a flow.
586
690
  """
587
691
  return f"{setting.get_app_namespace(trailing_delimiter='.')}{name}"
588
692
 
693
+
589
694
  def add_flow_def(name: str, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
590
695
  """Add a flow definition to the cocoindex library."""
591
- if not all(c.isalnum() or c == '_' for c in name):
592
- raise ValueError(f"Flow name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed.")
696
+ if not all(c.isalnum() or c == "_" for c in name):
697
+ raise ValueError(
698
+ f"Flow name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed."
699
+ )
593
700
  with _flows_lock:
594
701
  if name in _flows:
595
702
  raise KeyError(f"Flow with name {name} already exists")
596
703
  fl = _flows[name] = _create_lazy_flow(name, fl_def)
597
704
  return fl
598
705
 
599
- def flow_def(name = None) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]:
706
+
707
+ def flow_def(
708
+ name: str | None = None,
709
+ ) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]:
600
710
  """
601
711
  A decorator to wrap the flow definition.
602
712
  """
603
713
  return lambda fl_def: add_flow_def(name or fl_def.__name__, fl_def)
604
714
 
715
+
605
716
  def flow_names() -> list[str]:
606
717
  """
607
718
  Get the names of all flows.
@@ -609,6 +720,7 @@ def flow_names() -> list[str]:
609
720
  with _flows_lock:
610
721
  return list(_flows.keys())
611
722
 
723
+
612
724
  def flows() -> dict[str, Flow]:
613
725
  """
614
726
  Get all flows.
@@ -616,6 +728,7 @@ def flows() -> dict[str, Flow]:
616
728
  with _flows_lock:
617
729
  return dict(_flows)
618
730
 
731
+
619
732
  def flow_by_name(name: str) -> Flow:
620
733
  """
621
734
  Get a flow by name.
@@ -623,12 +736,14 @@ def flow_by_name(name: str) -> Flow:
623
736
  with _flows_lock:
624
737
  return _flows[name]
625
738
 
739
+
626
740
  def ensure_all_flows_built() -> None:
627
741
  """
628
742
  Ensure all flows are built.
629
743
  """
630
744
  execution_context.run(ensure_all_flows_built_async())
631
745
 
746
+
632
747
  async def ensure_all_flows_built_async() -> None:
633
748
  """
634
749
  Ensure all flows are built.
@@ -636,43 +751,63 @@ async def ensure_all_flows_built_async() -> None:
636
751
  for fl in flows().values():
637
752
  await fl.internal_flow_async()
638
753
 
639
- def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
754
+
755
+ def update_all_flows(
756
+ options: FlowLiveUpdaterOptions,
757
+ ) -> dict[str, _engine.IndexUpdateInfo]:
640
758
  """
641
759
  Update all flows.
642
760
  """
643
- return execution_context.run(update_all_flows_async(options))
761
+ return cast(
762
+ dict[str, _engine.IndexUpdateInfo],
763
+ execution_context.run(update_all_flows_async(options)),
764
+ )
644
765
 
645
- async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
766
+
767
+ async def update_all_flows_async(
768
+ options: FlowLiveUpdaterOptions,
769
+ ) -> dict[str, _engine.IndexUpdateInfo]:
646
770
  """
647
771
  Update all flows.
648
772
  """
649
773
  await ensure_all_flows_built_async()
774
+
650
775
  async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInfo]:
651
776
  async with FlowLiveUpdater(fl, options) as updater:
652
777
  await updater.wait_async()
653
778
  return (name, updater.update_stats())
779
+
654
780
  fls = flows()
655
- all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items()))
781
+ all_stats = await asyncio.gather(
782
+ *(_update_flow(name, fl) for (name, fl) in fls.items())
783
+ )
656
784
  return dict(all_stats)
657
785
 
658
- def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None:
786
+
787
+ def _get_data_slice_annotation_type(
788
+ data_slice_type: type[DataSlice[T] | inspect._empty],
789
+ ) -> type[T] | None:
659
790
  type_args = get_args(data_slice_type)
660
791
  if data_slice_type is inspect.Parameter.empty or data_slice_type is DataSlice:
661
792
  return None
662
793
  if get_origin(data_slice_type) != DataSlice or len(type_args) != 1:
663
794
  raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}")
664
- return type_args[0]
795
+ return cast(type[T] | None, type_args[0])
796
+
665
797
 
666
798
  _transform_flow_name_builder = _NameBuilder()
667
799
 
800
+
668
801
  class TransformFlowInfo(NamedTuple):
669
802
  engine_flow: _engine.TransientFlow
670
803
  result_decoder: Callable[[Any], T]
671
804
 
805
+
672
806
  class TransformFlow(Generic[T]):
673
807
  """
674
808
  A transient transformation flow that transforms in-memory data.
675
809
  """
810
+
676
811
  _flow_fn: Callable[..., DataSlice[T]]
677
812
  _flow_name: str
678
813
  _flow_arg_types: list[Any]
@@ -682,21 +817,27 @@ class TransformFlow(Generic[T]):
682
817
  _lazy_flow_info: TransformFlowInfo | None = None
683
818
 
684
819
  def __init__(
685
- self, flow_fn: Callable[..., DataSlice[T]],
686
- flow_arg_types: Sequence[Any], /, name: str | None = None):
820
+ self,
821
+ flow_fn: Callable[..., DataSlice[T]],
822
+ flow_arg_types: Sequence[Any],
823
+ /,
824
+ name: str | None = None,
825
+ ):
687
826
  self._flow_fn = flow_fn
688
- self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_")
827
+ self._flow_name = _transform_flow_name_builder.build_name(
828
+ name, prefix="_transform_flow_"
829
+ )
689
830
  self._flow_arg_types = list(flow_arg_types)
690
831
  self._lazy_lock = asyncio.Lock()
691
832
 
692
- def __call__(self, *args, **kwargs) -> DataSlice[T]:
833
+ def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]:
693
834
  return self._flow_fn(*args, **kwargs)
694
835
 
695
836
  @property
696
837
  def _flow_info(self) -> TransformFlowInfo:
697
838
  if self._lazy_flow_info is not None:
698
839
  return self._lazy_flow_info
699
- return execution_context.run(self._flow_info_async())
840
+ return cast(TransformFlowInfo, execution_context.run(self._flow_info_async()))
700
841
 
701
842
  async def _flow_info_async(self) -> TransformFlowInfo:
702
843
  if self._lazy_flow_info is not None:
@@ -712,35 +853,57 @@ class TransformFlow(Generic[T]):
712
853
  if len(sig.parameters) != len(self._flow_arg_types):
713
854
  raise ValueError(
714
855
  f"Number of parameters in the flow function ({len(sig.parameters)}) "
715
- f"does not match the number of argument types ({len(self._flow_arg_types)})")
856
+ f"does not match the number of argument types ({len(self._flow_arg_types)})"
857
+ )
716
858
 
717
- kwargs: dict[str, DataSlice] = {}
718
- for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types):
719
- if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
720
- inspect.Parameter.KEYWORD_ONLY):
721
- raise ValueError(f"Parameter `{param_name}` is not a parameter can be passed by name")
859
+ kwargs: dict[str, DataSlice[T]] = {}
860
+ for (param_name, param), param_type in zip(
861
+ sig.parameters.items(), self._flow_arg_types
862
+ ):
863
+ if param.kind not in (
864
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
865
+ inspect.Parameter.KEYWORD_ONLY,
866
+ ):
867
+ raise ValueError(
868
+ f"Parameter `{param_name}` is not a parameter can be passed by name"
869
+ )
722
870
  encoded_type = encode_enriched_type(param_type)
723
871
  if encoded_type is None:
724
872
  raise ValueError(f"Parameter `{param_name}` has no type annotation")
725
- engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(param_name, encoded_type)
726
- kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
873
+ engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
874
+ param_name, encoded_type
875
+ )
876
+ kwargs[param_name] = DataSlice(
877
+ _DataSliceState(flow_builder_state, engine_ds)
878
+ )
727
879
 
728
880
  output = self._flow_fn(**kwargs)
729
881
  flow_builder_state.engine_flow_builder.set_direct_output(
730
- _data_slice_state(output).engine_data_slice)
731
- engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop)
882
+ _data_slice_state(output).engine_data_slice
883
+ )
884
+ engine_flow = (
885
+ await flow_builder_state.engine_flow_builder.build_transient_flow_async(
886
+ execution_context.event_loop
887
+ )
888
+ )
732
889
  self._param_names = list(sig.parameters.keys())
733
890
 
734
- engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema()
735
- python_return_type = _get_data_slice_annotation_type(sig.return_annotation)
736
- result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type)
891
+ engine_return_type = (
892
+ _data_slice_state(output).engine_data_slice.data_type().schema()
893
+ )
894
+ python_return_type: type[T] | None = _get_data_slice_annotation_type(
895
+ sig.return_annotation
896
+ )
897
+ result_decoder = make_engine_value_decoder(
898
+ [], engine_return_type["type"], python_return_type
899
+ )
737
900
 
738
901
  return TransformFlowInfo(engine_flow, result_decoder)
739
902
 
740
- def __str__(self):
903
+ def __str__(self) -> str:
741
904
  return str(self._flow_info.engine_flow)
742
905
 
743
- def __repr__(self):
906
+ def __repr__(self) -> str:
744
907
  return repr(self._flow_info.engine_flow)
745
908
 
746
909
  def internal_flow(self) -> _engine.TransientFlow:
@@ -749,13 +912,13 @@ class TransformFlow(Generic[T]):
749
912
  """
750
913
  return self._flow_info.engine_flow
751
914
 
752
- def eval(self, *args, **kwargs) -> T:
915
+ def eval(self, *args: Any, **kwargs: Any) -> T:
753
916
  """
754
917
  Evaluate the transform flow.
755
918
  """
756
- return execution_context.run(self.eval_async(*args, **kwargs))
919
+ return cast(T, execution_context.run(self.eval_async(*args, **kwargs)))
757
920
 
758
- async def eval_async(self, *args, **kwargs) -> T:
921
+ async def eval_async(self, *args: Any, **kwargs: Any) -> T:
759
922
  """
760
923
  Evaluate the transform flow.
761
924
  """
@@ -776,18 +939,26 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]
776
939
  """
777
940
  A decorator to wrap the transform function.
778
941
  """
779
- def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]):
942
+
943
+ def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]:
780
944
  sig = inspect.signature(fn)
781
945
  arg_types = []
782
- for (param_name, param) in sig.parameters.items():
783
- if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
784
- inspect.Parameter.KEYWORD_ONLY):
785
- raise ValueError(f"Parameter `{param_name}` is not a parameter can be passed by name")
786
- value_type_annotation = _get_data_slice_annotation_type(param.annotation)
946
+ for param_name, param in sig.parameters.items():
947
+ if param.kind not in (
948
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
949
+ inspect.Parameter.KEYWORD_ONLY,
950
+ ):
951
+ raise ValueError(
952
+ f"Parameter `{param_name}` is not a parameter can be passed by name"
953
+ )
954
+ value_type_annotation: type[T] | None = _get_data_slice_annotation_type(
955
+ param.annotation
956
+ )
787
957
  if value_type_annotation is None:
788
958
  raise ValueError(
789
959
  f"Parameter `{param_name}` for {fn} has no value type annotation. "
790
- "Please use `cocoindex.DataSlice[T]` where T is the type of the value.")
960
+ "Please use `cocoindex.DataSlice[T]` where T is the type of the value."
961
+ )
791
962
  arg_types.append(value_type_annotation)
792
963
 
793
964
  _transform_flow = TransformFlow(fn, arg_types)