cocoindex 0.2.3__cp311-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +92 -0
- cocoindex/_engine.pyd +0 -0
- cocoindex/auth_registry.py +51 -0
- cocoindex/cli.py +697 -0
- cocoindex/convert.py +621 -0
- cocoindex/flow.py +1205 -0
- cocoindex/functions.py +357 -0
- cocoindex/index.py +29 -0
- cocoindex/lib.py +32 -0
- cocoindex/llm.py +46 -0
- cocoindex/op.py +628 -0
- cocoindex/py.typed +0 -0
- cocoindex/runtime.py +37 -0
- cocoindex/setting.py +181 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources.py +102 -0
- cocoindex/subprocess_exec.py +279 -0
- cocoindex/targets.py +135 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/conftest.py +38 -0
- cocoindex/tests/test_convert.py +1543 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +207 -0
- cocoindex/tests/test_typing.py +429 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +473 -0
- cocoindex/user_app_loader.py +51 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.2.3.dist-info/METADATA +262 -0
- cocoindex-0.2.3.dist-info/RECORD +34 -0
- cocoindex-0.2.3.dist-info/WHEEL +4 -0
- cocoindex-0.2.3.dist-info/entry_points.txt +2 -0
- cocoindex-0.2.3.dist-info/licenses/LICENSE +201 -0
cocoindex/flow.py
ADDED
@@ -0,0 +1,1205 @@
|
|
1
|
+
"""
|
2
|
+
Flow is the main interface for building and running flows.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import datetime
|
9
|
+
import functools
|
10
|
+
import inspect
|
11
|
+
import re
|
12
|
+
from dataclasses import dataclass
|
13
|
+
from enum import Enum
|
14
|
+
from threading import Lock
|
15
|
+
from typing import (
|
16
|
+
Any,
|
17
|
+
Callable,
|
18
|
+
Generic,
|
19
|
+
Iterable,
|
20
|
+
NamedTuple,
|
21
|
+
Sequence,
|
22
|
+
TypeVar,
|
23
|
+
cast,
|
24
|
+
get_args,
|
25
|
+
get_origin,
|
26
|
+
)
|
27
|
+
|
28
|
+
from rich.text import Text
|
29
|
+
from rich.tree import Tree
|
30
|
+
|
31
|
+
from . import _engine # type: ignore
|
32
|
+
from . import index
|
33
|
+
from . import op
|
34
|
+
from . import setting
|
35
|
+
from .convert import (
|
36
|
+
dump_engine_object,
|
37
|
+
make_engine_value_decoder,
|
38
|
+
make_engine_value_encoder,
|
39
|
+
)
|
40
|
+
from .op import FunctionSpec
|
41
|
+
from .runtime import execution_context
|
42
|
+
from .setup import SetupChangeBundle
|
43
|
+
from .typing import analyze_type_info, encode_enriched_type
|
44
|
+
from .validation import (
|
45
|
+
validate_flow_name,
|
46
|
+
validate_full_flow_name,
|
47
|
+
validate_target_name,
|
48
|
+
)
|
49
|
+
|
50
|
+
|
51
|
+
class _NameBuilder:
    """Allocates unique names, auto-numbering anonymous ones per prefix."""

    _existing_names: set[str]
    _next_name_index: dict[str, int]

    def __init__(self) -> None:
        self._existing_names = set()
        self._next_name_index = {}

    def build_name(self, name: str | None, /, prefix: str) -> str:
        """
        Build a name. If the name is None, generate a name with the given prefix.
        """
        # An explicit name is taken as-is; remember it so generated names avoid it.
        if name is not None:
            self._existing_names.add(name)
            return name

        # Generate "<prefix><idx>", skipping anything already reserved.
        candidate: str | None = None
        idx = self._next_name_index.get(prefix, 0)
        while candidate is None or candidate in self._existing_names:
            candidate = f"{prefix}{idx}"
            idx += 1
        self._next_name_index[prefix] = idx
        self._existing_names.add(candidate)
        return candidate
|
75
|
+
|
76
|
+
|
77
|
+
# Matches the zero-width position before each uppercase letter (except at the
# start of the string), i.e. CamelCase word boundaries.
_WORD_BOUNDARY_RE = re.compile("(?<!^)(?=[A-Z])")


def _to_snake_case(name: str) -> str:
    """Convert a CamelCase identifier to snake_case."""
    with_separators = _WORD_BOUNDARY_RE.sub("_", name)
    return with_separators.lower()
|
82
|
+
|
83
|
+
|
84
|
+
def _create_data_slice(
    flow_builder_state: _FlowBuilderState,
    creator: Callable[[_engine.DataScopeRef | None, str | None], _engine.DataSlice],
    name: str | None = None,
) -> DataSlice[T]:
    """
    Wrap an engine data-slice creator into a `DataSlice`.

    When `name` is None, creation is deferred: the creator runs lazily, either
    when the slice is attached to a scope field (then it receives that target
    scope and field name) or on first access (then it receives (None, None)).
    When `name` is given, the creator runs eagerly with that name.
    """
    if name is None:
        # Lazy path: `_DataSliceState` calls the lambda later with either a
        # (scope, field_name) tuple or None; unpack accordingly.
        return DataSlice(
            _DataSliceState(
                flow_builder_state,
                lambda target: creator(target[0], target[1])
                if target is not None
                else creator(None, None),
            )
        )
    else:
        # Eager path: an explicit name pins the slice immediately.
        return DataSlice(_DataSliceState(flow_builder_state, creator(None, name)))
|
100
|
+
|
101
|
+
|
102
|
+
def _spec_kind(spec: Any) -> str:
    """Return the operation kind of *spec*, derived from its class name."""
    class_name = spec.__class__.__name__
    return cast(str, class_name)
|
104
|
+
|
105
|
+
|
106
|
+
def _transform_helper(
    flow_builder_state: _FlowBuilderState,
    fn_spec: FunctionSpec | Callable[..., Any],
    transform_args: list[tuple[Any, str | None]],
    name: str | None = None,
) -> DataSlice[Any]:
    """
    Shared implementation behind `FlowBuilder.transform` and `DataSlice.transform`.

    Resolves `fn_spec` into an (operation kind, spec) pair, then registers the
    transform with the engine lazily via `_create_data_slice`.

    Raises:
        ValueError: if `fn_spec` is neither a `FunctionSpec` nor a callable
            decorated as a CocoIndex op (i.e. carrying `__cocoindex_op_kind__`).
    """
    if isinstance(fn_spec, FunctionSpec):
        kind = _spec_kind(fn_spec)
        spec = fn_spec
    elif callable(fn_spec) and (
        op_kind := getattr(fn_spec, "__cocoindex_op_kind__", None)
    ):
        # A callable decorated by the op framework: the decorator stored its
        # registered kind on the function; there is no extra spec payload.
        kind = op_kind
        spec = op.EmptyFunctionSpec()
    else:
        raise ValueError("transform() can only be called on a CocoIndex function")

    def _create_data_slice_inner(
        target_scope: _engine.DataScopeRef | None, name: str | None
    ) -> _engine.DataSlice:
        # Registers the transform with the engine; the field name is
        # auto-generated from the spec's class name when none is given.
        result = flow_builder_state.engine_flow_builder.transform(
            kind,
            dump_engine_object(spec),
            transform_args,
            target_scope,
            flow_builder_state.field_name_builder.build_name(
                name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
            ),
        )
        return result

    return _create_data_slice(
        flow_builder_state,
        _create_data_slice_inner,
        name,
    )
|
142
|
+
|
143
|
+
|
144
|
+
# Generic type parameters used by the slice/scope APIs below.
T = TypeVar("T")  # value type carried by a DataSlice
S = TypeVar("S")  # return type of a callable passed to DataSlice.call()
|
146
|
+
|
147
|
+
|
148
|
+
class _DataSliceState:
    """
    Internal state behind a `DataSlice`.

    A slice is either eager (created with a concrete engine DataSlice) or lazy
    (created with a creator callable). A lazy slice materializes at most once:
    either on first access, or when attached to a scope field.
    """

    flow_builder_state: _FlowBuilderState

    _lazy_lock: Lock | None = None  # None means it's not lazy.
    _data_slice: _engine.DataSlice | None = None
    # Creator for the lazy case; receives the (scope, field_name) attachment
    # target, or None when resolved without being attached anywhere.
    _data_slice_creator: (
        Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice] | None
    ) = None

    def __init__(
        self,
        flow_builder_state: _FlowBuilderState,
        data_slice: _engine.DataSlice
        | Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice],
    ):
        self.flow_builder_state = flow_builder_state

        if isinstance(data_slice, _engine.DataSlice):
            # Eager: already materialized, no lock needed.
            self._data_slice = data_slice
        else:
            # Lazy: the lock serializes the one-time materialization.
            self._lazy_lock = Lock()
            self._data_slice_creator = data_slice

    @property
    def engine_data_slice(self) -> _engine.DataSlice:
        """
        Get the internal DataSlice, materializing a lazy one on first access.
        This can be blocking.
        """
        if self._lazy_lock is None:
            if self._data_slice is None:
                raise ValueError("Data slice is not initialized")
            return self._data_slice
        else:
            if self._data_slice_creator is None:
                raise ValueError("Data slice creator is not initialized")
            # Materialize under the lock so concurrent readers create it once.
            with self._lazy_lock:
                if self._data_slice is None:
                    self._data_slice = self._data_slice_creator(None)
                return self._data_slice

    async def engine_data_slice_async(self) -> _engine.DataSlice:
        """
        Get the internal DataSlice. Async version: the potentially-blocking
        materialization runs on a worker thread.
        """
        return await asyncio.to_thread(lambda: self.engine_data_slice)

    def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
        """
        Attach the current data slice (if not yet attached) to the given scope.
        """
        if self._lazy_lock is not None:
            with self._lazy_lock:
                if self._data_slice_creator is None:
                    raise ValueError("Data slice creator is not initialized")
                if self._data_slice is None:
                    # Materialize directly into the target scope/field.
                    self._data_slice = self._data_slice_creator((scope, field_name))
                return
        # Eager or already-materialized slices cannot be re-attached today.
        # TODO: We'll support this by an identity transformer or "aliasing" in the future.
        raise ValueError("DataSlice is already attached to a field")
|
209
|
+
|
210
|
+
|
211
|
+
class DataSlice(Generic[T]):
    """A data slice represents a slice of data in a flow. It's readonly."""

    _state: _DataSliceState

    def __init__(self, state: _DataSliceState):
        self._state = state

    def __str__(self) -> str:
        return str(self._state.engine_data_slice)

    def __repr__(self) -> str:
        return repr(self._state.engine_data_slice)

    def __getitem__(self, field_name: str) -> DataSlice[T]:
        # Sub-slice addressing a single field; KeyError mirrors dict semantics.
        field_slice = self._state.engine_data_slice.field(field_name)
        if field_slice is None:
            raise KeyError(field_name)
        return DataSlice(_DataSliceState(self._state.flow_builder_state, field_slice))

    def row(
        self,
        /,
        *,
        max_inflight_rows: int | None = None,
        max_inflight_bytes: int | None = None,
    ) -> DataScope:
        """
        Return a scope representing each row of the table.

        The inflight limits bound how many rows / bytes the engine processes
        concurrently for this per-row scope; None leaves the engine default.
        """
        row_scope = self._state.flow_builder_state.engine_flow_builder.for_each(
            self._state.engine_data_slice,
            execution_options=dump_engine_object(
                _ExecutionOptions(
                    max_inflight_rows=max_inflight_rows,
                    max_inflight_bytes=max_inflight_bytes,
                ),
            ),
        )
        return DataScope(self._state.flow_builder_state, row_scope)

    def for_each(
        self,
        f: Callable[[DataScope], None],
        /,
        *,
        max_inflight_rows: int | None = None,
        max_inflight_bytes: int | None = None,
    ) -> None:
        """
        Apply a function to each row of the collection.
        """
        # Use the scope as a context manager so it is released after `f` runs.
        with self.row(
            max_inflight_rows=max_inflight_rows,
            max_inflight_bytes=max_inflight_bytes,
        ) as scope:
            f(scope)

    def transform(
        self, fn_spec: op.FunctionSpec | Callable[..., Any], *args: Any, **kwargs: Any
    ) -> DataSlice[Any]:
        """
        Apply a function to the data slice.

        This slice becomes the first positional argument; extra positional and
        keyword arguments (plain values or other DataSlices) follow.
        """
        transform_args: list[tuple[Any, str | None]] = [
            (self._state.engine_data_slice, None)
        ]
        transform_args += [
            (self._state.flow_builder_state.get_data_slice(v), None) for v in args
        ]
        transform_args += [
            (self._state.flow_builder_state.get_data_slice(v), k)
            for k, v in kwargs.items()
        ]

        return _transform_helper(
            self._state.flow_builder_state, fn_spec, transform_args
        )

    def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S:
        """
        Call a function with the data slice.

        Convenience for fluent chaining: `slice.call(f, ...)` == `f(slice, ...)`.
        """
        return func(self, *args, **kwargs)
|
295
|
+
|
296
|
+
|
297
|
+
def _data_slice_state(data_slice: DataSlice[T]) -> _DataSliceState:
    """Return the private state of a DataSlice (module-internal accessor)."""
    state = data_slice._state  # pylint: disable=protected-access
    return state
|
299
|
+
|
300
|
+
|
301
|
+
class DataScope:
    """
    A data scope in a flow.
    It has multiple fields and collectors, and allows users to add new fields
    and collectors.
    """

    _flow_builder_state: _FlowBuilderState
    _engine_data_scope: _engine.DataScopeRef

    def __init__(
        self, flow_builder_state: _FlowBuilderState, data_scope: _engine.DataScopeRef
    ):
        self._flow_builder_state = flow_builder_state
        self._engine_data_scope = data_scope

    def __str__(self) -> str:
        return str(self._engine_data_scope)

    def __repr__(self) -> str:
        return repr(self._engine_data_scope)

    def __getitem__(self, field_name: str) -> DataSlice[T]:
        # Read access: wrap the scope field as a DataSlice.
        return DataSlice(
            _DataSliceState(
                self._flow_builder_state,
                self._flow_builder_state.engine_flow_builder.scope_field(
                    self._engine_data_scope, field_name
                ),
            )
        )

    def __setitem__(self, field_name: str, value: DataSlice[T]) -> None:
        # Write access: attach an (unattached) DataSlice as a new field.
        # Imported locally, presumably to avoid an import cycle — TODO confirm.
        from .validation import validate_field_name

        validate_field_name(field_name)
        value._state.attach_to_scope(self._engine_data_scope, field_name)

    def __enter__(self) -> DataScope:
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Drop the engine reference on scope exit; further use of this object
        # after the `with` block will raise AttributeError.
        del self._engine_data_scope

    def add_collector(self, name: str | None = None) -> DataCollector:
        """
        Add a collector to the flow.

        An auto-generated "_collector_N" name is used when none is given.
        """
        return DataCollector(
            self._flow_builder_state,
            self._engine_data_scope.add_collector(
                self._flow_builder_state.field_name_builder.build_name(
                    name, prefix="_collector_"
                )
            ),
        )
|
356
|
+
|
357
|
+
|
358
|
+
class GeneratedField(Enum):
    """
    A generated field is automatically set by the engine.

    Pass a member as a keyword value to `DataCollector.collect()` to have the
    engine fill in that field.
    """

    # The engine generates a UUID value for the field.
    UUID = "Uuid"
|
364
|
+
|
365
|
+
|
366
|
+
class DataCollector:
    """A data collector is used to collect data into a collector."""

    _flow_builder_state: _FlowBuilderState
    _engine_data_collector: _engine.DataCollector

    def __init__(
        self,
        flow_builder_state: _FlowBuilderState,
        data_collector: _engine.DataCollector,
    ):
        self._flow_builder_state = flow_builder_state
        self._engine_data_collector = data_collector

    def collect(self, **kwargs: Any) -> None:
        """
        Collect data into the collector.

        Each keyword is a field; values may be DataSlices, plain constants, or
        `GeneratedField.UUID` (at most one) to let the engine generate the value.

        Raises:
            ValueError: on a second UUID field or an unknown GeneratedField.
        """
        regular_kwargs = []
        auto_uuid_field = None
        for k, v in kwargs.items():
            if isinstance(v, GeneratedField):
                if v == GeneratedField.UUID:
                    if auto_uuid_field is not None:
                        raise ValueError("Only one generated UUID field is allowed")
                    auto_uuid_field = k
                else:
                    raise ValueError(f"Unexpected generated field: {v}")
            else:
                # Plain values are lifted to engine constants here.
                regular_kwargs.append((k, self._flow_builder_state.get_data_slice(v)))

        self._flow_builder_state.engine_flow_builder.collect(
            self._engine_data_collector, regular_kwargs, auto_uuid_field
        )

    def export(
        self,
        target_name: str,
        target_spec: op.TargetSpec,
        /,
        *,
        primary_key_fields: Sequence[str],
        vector_indexes: Sequence[index.VectorIndexDef] = (),
        vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (),
        setup_by_user: bool = False,
    ) -> None:
        """
        Export the collected data to the specified target.

        `vector_index` is for backward compatibility only. Please use `vector_indexes` instead.

        Raises:
            ValueError: if `target_spec` is not a CocoIndex `TargetSpec`.
        """

        validate_target_name(target_name)
        if not isinstance(target_spec, op.TargetSpec):
            raise ValueError(
                "export() can only be called on a CocoIndex target storage"
            )

        # For backward compatibility only: translate legacy (field, metric)
        # pairs, but only when the new-style argument wasn't supplied.
        if len(vector_indexes) == 0 and len(vector_index) > 0:
            vector_indexes = [
                index.VectorIndexDef(field_name=field_name, metric=metric)
                for field_name, metric in vector_index
            ]

        index_options = index.IndexOptions(
            primary_key_fields=primary_key_fields,
            vector_indexes=vector_indexes,
        )
        self._flow_builder_state.engine_flow_builder.export(
            target_name,
            _spec_kind(target_spec),
            dump_engine_object(target_spec),
            dump_engine_object(index_options),
            self._engine_data_collector,
            setup_by_user,
        )
|
443
|
+
|
444
|
+
|
445
|
+
# Module-level allocator so anonymous flows receive unique "_flow_N" names.
_flow_name_builder = _NameBuilder()
|
446
|
+
|
447
|
+
|
448
|
+
class _FlowBuilderState:
    """
    Internal state shared by FlowBuilder, DataSlice and DataScope while a flow
    is being defined: the engine-side builder plus the field-name allocator.
    """

    engine_flow_builder: _engine.FlowBuilder
    field_name_builder: _NameBuilder

    def __init__(self, full_name: str):
        self.engine_flow_builder = _engine.FlowBuilder(full_name)
        self.field_name_builder = _NameBuilder()

    def get_data_slice(self, v: Any) -> _engine.DataSlice:
        """
        Return a data slice that represents the given value.

        A DataSlice is unwrapped to its engine slice; anything else becomes an
        engine constant typed from its Python runtime type.
        """
        if isinstance(v, DataSlice):
            return v._state.engine_data_slice
        return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v)
|
467
|
+
|
468
|
+
|
469
|
+
@dataclass
class _SourceRefreshOptions:
    """
    Options for refreshing a source.

    Serialized with `dump_engine_object` and passed to the engine.
    """

    # Interval for periodic source re-scans; presumably None disables periodic
    # refresh — engine-defined, verify against engine docs.
    refresh_interval: datetime.timedelta | None = None
|
476
|
+
|
477
|
+
|
478
|
+
@dataclass
class _ExecutionOptions:
    """
    Concurrency limits serialized and handed to the engine for row processing.

    None presumably leaves the engine's default limit in place — TODO confirm.
    """

    max_inflight_rows: int | None = None
    max_inflight_bytes: int | None = None
|
482
|
+
|
483
|
+
|
484
|
+
class FlowBuilder:
    """
    A flow builder is used to build a flow.

    Thin public facade over `_FlowBuilderState`; passed to user flow
    definitions together with the root `DataScope`.
    """

    _state: _FlowBuilderState

    def __init__(self, state: _FlowBuilderState):
        self._state = state

    def __str__(self) -> str:
        return str(self._state.engine_flow_builder)

    def __repr__(self) -> str:
        return repr(self._state.engine_flow_builder)

    def add_source(
        self,
        spec: op.SourceSpec,
        /,
        *,
        name: str | None = None,
        refresh_interval: datetime.timedelta | None = None,
        max_inflight_rows: int | None = None,
        max_inflight_bytes: int | None = None,
    ) -> DataSlice[T]:
        """
        Import a source to the flow.

        Creation is lazy via `_create_data_slice` when `name` is None; the
        field name then derives from the spec class name, snake_cased.

        Raises:
            ValueError: if `spec` is not a CocoIndex `SourceSpec`.
        """
        if not isinstance(spec, op.SourceSpec):
            raise ValueError("add_source() can only be called on a CocoIndex source")
        return _create_data_slice(
            self._state,
            lambda target_scope, name: self._state.engine_flow_builder.add_source(
                _spec_kind(spec),
                dump_engine_object(spec),
                target_scope,
                self._state.field_name_builder.build_name(
                    name, prefix=_to_snake_case(_spec_kind(spec)) + "_"
                ),
                refresh_options=dump_engine_object(
                    _SourceRefreshOptions(refresh_interval=refresh_interval)
                ),
                execution_options=dump_engine_object(
                    _ExecutionOptions(
                        max_inflight_rows=max_inflight_rows,
                        max_inflight_bytes=max_inflight_bytes,
                    )
                ),
            ),
            name,
        )

    def transform(
        self, fn_spec: FunctionSpec | Callable[..., Any], *args: Any, **kwargs: Any
    ) -> DataSlice[Any]:
        """
        Apply a function to inputs, returning a DataSlice.

        Unlike `DataSlice.transform`, no implicit first argument is added, so
        at least one positional or keyword input must be supplied.

        Raises:
            ValueError: when no inputs are given.
        """
        transform_args: list[tuple[Any, str | None]] = [
            (self._state.get_data_slice(v), None) for v in args
        ]
        transform_args += [
            (self._state.get_data_slice(v), k) for k, v in kwargs.items()
        ]

        if not transform_args:
            raise ValueError("At least one input is required for transformation")

        return _transform_helper(self._state, fn_spec, transform_args)

    def declare(self, spec: op.DeclarationSpec) -> None:
        """
        Add a declaration to the flow.
        """
        self._state.engine_flow_builder.declare(dump_engine_object(spec))
|
560
|
+
|
561
|
+
|
562
|
+
@dataclass
class FlowLiveUpdaterOptions:
    """
    Options for live updating a flow.
    """

    # True: keep watching sources for changes; False: one-shot update
    # (as used by `Flow.update_async`).
    live_mode: bool = True
    # Presumably makes the engine print processing stats — TODO confirm
    # against engine behavior.
    print_stats: bool = False
|
570
|
+
|
571
|
+
|
572
|
+
class FlowUpdaterStatusUpdates(NamedTuple):
    """
    Status updates for a flow updater.

    Returned by `FlowLiveUpdater.next_status_updates[_async]`.
    """

    # Sources that are still active, i.e. not stopped processing.
    active_sources: list[str]

    # Sources with updates since last time.
    updated_sources: list[str]
|
582
|
+
|
583
|
+
|
584
|
+
class FlowLiveUpdater:
    """
    A live updater for a flow.

    Usable as a sync or async context manager: entering starts the updater,
    exiting aborts it and waits for it to drain.
    """

    _flow: Flow
    _options: FlowLiveUpdaterOptions
    # Set by start()/start_async(); None until then.
    _engine_live_updater: _engine.FlowLiveUpdater | None = None

    def __init__(self, fl: Flow, options: FlowLiveUpdaterOptions | None = None):
        self._flow = fl
        self._options = options or FlowLiveUpdaterOptions()

    def __enter__(self) -> FlowLiveUpdater:
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Abort first, then wait, so the context only exits once the engine
        # updater has actually stopped.
        self.abort()
        self.wait()

    async def __aenter__(self) -> FlowLiveUpdater:
        await self.start_async()
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.abort()
        await self.wait_async()

    def start(self) -> None:
        """
        Start the live updater.
        """
        execution_context.run(self.start_async())

    async def start_async(self) -> None:
        """
        Start the live updater. Async version.

        Builds the flow first if it hasn't been built yet.
        """
        self._engine_live_updater = await _engine.FlowLiveUpdater.create(
            await self._flow.internal_flow_async(), dump_engine_object(self._options)
        )

    def wait(self) -> None:
        """
        Wait for the live updater to finish.
        """
        execution_context.run(self.wait_async())

    async def wait_async(self) -> None:
        """
        Wait for the live updater to finish. Async version.
        """
        await self._get_engine_live_updater().wait_async()

    def next_status_updates(self) -> FlowUpdaterStatusUpdates:
        """
        Get the next status updates.

        It blocks until there's a new status updates, including the processing finishes for a bunch of source updates,
        and live updater stops (aborted, or no more sources to process).
        """
        return execution_context.run(self.next_status_updates_async())

    async def next_status_updates_async(self) -> FlowUpdaterStatusUpdates:
        """
        Get the next status updates. Async version.
        """
        updates = await self._get_engine_live_updater().next_status_updates_async()
        return FlowUpdaterStatusUpdates(
            active_sources=updates.active_sources,
            updated_sources=updates.updated_sources,
        )

    def abort(self) -> None:
        """
        Abort the live updater.
        """
        self._get_engine_live_updater().abort()

    def update_stats(self) -> _engine.IndexUpdateInfo:
        """
        Get the index update info.
        """
        return self._get_engine_live_updater().index_update_info()

    def _get_engine_live_updater(self) -> _engine.FlowLiveUpdater:
        # Guard against using the updater before start()/start_async().
        if self._engine_live_updater is None:
            raise RuntimeError("Live updater is not started")
        return self._engine_live_updater
|
674
|
+
|
675
|
+
|
676
|
+
@dataclass
class EvaluateAndDumpOptions:
    """
    Options for evaluating and dumping a flow.
    """

    # Directory where the evaluated flow outputs are written.
    output_dir: str
    # Presumably reuses cached intermediate results when True — TODO confirm
    # against engine behavior.
    use_cache: bool = True
|
684
|
+
|
685
|
+
|
686
|
+
class Flow:
    """
    A flow describes an indexing pipeline.

    The underlying engine flow is built lazily, on first use, and cached.
    """

    _name: str
    _full_name: str
    # Returns the (lazily built) engine flow; None once close() was called.
    _lazy_engine_flow: Callable[[], _engine.Flow] | None

    def __init__(
        self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow]
    ):
        validate_flow_name(name)
        validate_full_flow_name(full_name)
        self._name = name
        self._full_name = full_name
        engine_flow = None
        lock = Lock()

        def _lazy_engine_flow() -> _engine.Flow:
            # Double-checked locking: build the engine flow exactly once even
            # under concurrent first access.
            nonlocal engine_flow, lock
            if engine_flow is None:
                with lock:
                    if engine_flow is None:
                        engine_flow = engine_flow_creator()
            return engine_flow

        self._lazy_engine_flow = _lazy_engine_flow

    def _render_spec(self, verbose: bool = False) -> Tree:
        """
        Render the flow spec as a styled rich Tree with hierarchical structure.
        """
        spec = self._get_spec(verbose=verbose)
        tree = Tree(f"Flow: {self.full_name}", style="cyan")

        def build_tree(label: str, lines: list[Any]) -> Tree:
            # Empty sections render as "<label> None".
            node = Tree(label=label if lines else label + " None", style="cyan")
            for line in lines:
                child_node = node.add(Text(line.content, style="yellow"))
                # Graft grandchildren by recursing with an empty label.
                child_node.children = build_tree("", line.children).children
            return node

        for section, lines in spec.sections:
            section_node = build_tree(f"{section}:", lines)
            tree.children.append(section_node)
        return tree

    def _get_spec(self, verbose: bool = False) -> _engine.RenderedSpec:
        # Ask the engine to render the spec; "concise" omits detail.
        return self.internal_flow().get_spec(
            output_mode="verbose" if verbose else "concise"
        )

    def _get_schema(self) -> list[tuple[str, str, str]]:
        return cast(list[tuple[str, str, str]], self.internal_flow().get_schema())

    def __str__(self) -> str:
        return str(self._get_spec())

    def __repr__(self) -> str:
        return repr(self.internal_flow())

    @property
    def name(self) -> str:
        """
        Get the name of the flow.
        """
        return self._name

    @property
    def full_name(self) -> str:
        """
        Get the full name of the flow (app namespace + name).
        """
        return self._full_name

    def update(self) -> _engine.IndexUpdateInfo:
        """
        Update the index defined by the flow.
        Once the function returns, the index is fresh up to the moment when the function is called.
        """
        return execution_context.run(self.update_async())

    async def update_async(self) -> _engine.IndexUpdateInfo:
        """
        Update the index defined by the flow.
        Once the function returns, the index is fresh up to the moment when the function is called.
        """
        # A non-live updater performs a one-shot update and then stops.
        async with FlowLiveUpdater(
            self, FlowLiveUpdaterOptions(live_mode=False)
        ) as updater:
            await updater.wait_async()
            return updater.update_stats()

    def evaluate_and_dump(
        self, options: EvaluateAndDumpOptions
    ) -> _engine.IndexUpdateInfo:
        """
        Evaluate the flow and dump flow outputs to files.
        """
        return self.internal_flow().evaluate_and_dump(dump_engine_object(options))

    def internal_flow(self) -> _engine.Flow:
        """
        Get the engine flow, building it on first call.

        Raises:
            RuntimeError: if the flow was removed via `close()`.
        """
        if self._lazy_engine_flow is None:
            raise RuntimeError(f"Flow {self.full_name} is already removed")
        return self._lazy_engine_flow()

    async def internal_flow_async(self) -> _engine.Flow:
        """
        Get the engine flow. The async version: the potentially-blocking build
        runs on a worker thread.
        """
        return await asyncio.to_thread(self.internal_flow)

    def setup(self, report_to_stdout: bool = False) -> None:
        """
        Setup persistent backends of the flow.
        """
        execution_context.run(self.setup_async(report_to_stdout=report_to_stdout))

    async def setup_async(self, report_to_stdout: bool = False) -> None:
        """
        Setup persistent backends of the flow. The async version.
        """
        bundle = await make_setup_bundle_async([self])
        await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)

    def drop(self, report_to_stdout: bool = False) -> None:
        """
        Drop persistent backends of the flow.

        The current instance is still valid after it's called.
        For example, you can still call `setup()` after it, to setup the persistent backends again.

        Call `close()` if you want to remove the flow from the current process.
        """
        execution_context.run(self.drop_async(report_to_stdout=report_to_stdout))

    async def drop_async(self, report_to_stdout: bool = False) -> None:
        """
        Drop persistent backends of the flow. The async version.
        """
        bundle = await make_drop_bundle_async([self])
        await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)

    def close(self) -> None:
        """
        Close the flow. It will remove the flow from the current process to free up resources.
        After it's called, methods of the flow should no longer be called.

        This will NOT touch the persistent backends of the flow.
        """
        _engine.remove_flow_context(self.full_name)
        self._lazy_engine_flow = None
        # Also unregister from the module-level registry.
        with _flows_lock:
            del _flows[self.name]
|
844
|
+
|
845
|
+
|
846
|
+
def _create_lazy_flow(
    name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]
) -> Flow:
    """
    Create a flow without really building it yet.
    The flow will be built the first time when it's really needed.
    """
    # Name and full name are resolved eagerly so registration/lookup work
    # before the flow body ever runs.
    flow_name = _flow_name_builder.build_name(name, prefix="_flow_")
    flow_full_name = get_flow_full_name(flow_name)

    def _create_engine_flow() -> _engine.Flow:
        # Runs the user's flow definition against a fresh builder, then asks
        # the engine to finalize the flow on the shared event loop.
        flow_builder_state = _FlowBuilderState(flow_full_name)
        root_scope = DataScope(
            flow_builder_state, flow_builder_state.engine_flow_builder.root_scope()
        )
        fl_def(FlowBuilder(flow_builder_state), root_scope)
        return flow_builder_state.engine_flow_builder.build_flow(
            execution_context.event_loop
        )

    return Flow(flow_name, flow_full_name, _create_engine_flow)
|
867
|
+
|
868
|
+
|
869
|
+
# Global registry of opened flows, keyed by flow name (without app namespace).
# All access goes through _flows_lock for thread safety.
_flows_lock = Lock()
_flows: dict[str, Flow] = {}
|
871
|
+
|
872
|
+
|
873
|
+
def get_flow_full_name(name: str) -> str:
    """
    Get the full name of a flow: the app namespace prefix followed by the name.
    """
    namespace_prefix = setting.get_app_namespace(trailing_delimiter=".")
    return f"{namespace_prefix}{name}"
|
878
|
+
|
879
|
+
|
880
|
+
def open_flow(name: str, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
    """
    Open a flow, with the given name and definition.

    Raises:
        KeyError: if a flow with the same name is already registered.
    """
    with _flows_lock:
        if name in _flows:
            raise KeyError(f"Flow with name {name} already exists")
        new_flow = _create_lazy_flow(name, fl_def)
        _flows[name] = new_flow
    return new_flow
|
889
|
+
|
890
|
+
|
891
|
+
def add_flow_def(name: str, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
    """
    DEPRECATED: Use `open_flow()` instead.
    """
    # Backward-compatibility shim; delegates directly to the new API.
    return open_flow(name, fl_def)
|
896
|
+
|
897
|
+
|
898
|
+
def remove_flow(fl: Flow) -> None:
    """
    DEPRECATED: Use `Flow.close()` instead.
    """
    # Backward-compatibility shim; delegates to the instance method.
    fl.close()
|
903
|
+
|
904
|
+
|
905
|
+
def flow_def(
    name: str | None = None,
) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]:
    """
    A decorator to wrap the flow definition.

    The flow is registered under `name`, falling back to the decorated
    function's own name when no explicit name is given.
    """

    def _decorator(fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
        return open_flow(name or fl_def.__name__, fl_def)

    return _decorator
|
912
|
+
|
913
|
+
|
914
|
+
def flow_names() -> list[str]:
    """
    Get the names of all flows (snapshot taken under the registry lock).
    """
    with _flows_lock:
        return [*_flows]
|
920
|
+
|
921
|
+
|
922
|
+
def flows() -> dict[str, Flow]:
    """
    Get all flows, as a shallow snapshot copy of the registry.
    """
    with _flows_lock:
        snapshot = dict(_flows)
    return snapshot
|
928
|
+
|
929
|
+
|
930
|
+
def flow_by_name(name: str) -> Flow:
    """
    Get a flow by name.

    Raises:
        KeyError: if no flow with the given name is registered.
    """
    with _flows_lock:
        fl = _flows[name]
    return fl
|
936
|
+
|
937
|
+
|
938
|
+
def ensure_all_flows_built() -> None:
    """
    Ensure all flows are built.

    Synchronous wrapper around `ensure_all_flows_built_async()`.
    """
    execution_context.run(ensure_all_flows_built_async())
|
943
|
+
|
944
|
+
|
945
|
+
async def ensure_all_flows_built_async() -> None:
    """
    Ensure all flows are built, one at a time.
    """
    for flow in flows().values():
        # Building is lazy; awaiting the internal flow forces it.
        await flow.internal_flow_async()
|
951
|
+
|
952
|
+
|
953
|
+
def update_all_flows(
    options: FlowLiveUpdaterOptions,
) -> dict[str, _engine.IndexUpdateInfo]:
    """
    Update all flows.

    Synchronous wrapper around `update_all_flows_async()`.
    """
    return execution_context.run(update_all_flows_async(options))
|
960
|
+
|
961
|
+
|
962
|
+
async def update_all_flows_async(
    options: FlowLiveUpdaterOptions,
) -> dict[str, _engine.IndexUpdateInfo]:
    """
    Update all flows concurrently, returning per-flow update stats keyed by
    flow name.
    """
    await ensure_all_flows_built_async()

    async def _run_one(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInfo]:
        # Each flow gets its own live updater; wait for it to complete a pass.
        async with FlowLiveUpdater(fl, options) as updater:
            await updater.wait_async()
            return name, updater.update_stats()

    registry_snapshot = flows()
    results = await asyncio.gather(
        *(_run_one(name, fl) for name, fl in registry_snapshot.items())
    )
    return dict(results)
|
980
|
+
|
981
|
+
|
982
|
+
def _get_data_slice_annotation_type(
    data_slice_type: type[DataSlice[T] | inspect._empty],
) -> type[T] | None:
    """
    Extract `T` from a `DataSlice[T]` annotation.

    Returns None for a missing annotation or a bare `DataSlice`; raises
    ValueError for any other annotation shape.
    """
    if data_slice_type is inspect.Parameter.empty or data_slice_type is DataSlice:
        return None
    type_args = get_args(data_slice_type)
    if get_origin(data_slice_type) is not DataSlice or len(type_args) != 1:
        raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}")
    return cast(type[T] | None, type_args[0])
|
991
|
+
|
992
|
+
|
993
|
+
# Name builder for auto-generated transform flow names (prefix "_transform_flow_").
_transform_flow_name_builder = _NameBuilder()
|
994
|
+
|
995
|
+
|
996
|
+
class TransformFlowInfo(NamedTuple):
    """Lazily-built artifacts of a transform flow."""

    # The built transient engine flow.
    engine_flow: _engine.TransientFlow
    # Decodes the engine's raw evaluation result into the Python type `T`.
    result_decoder: Callable[[Any], T]
|
999
|
+
|
1000
|
+
|
1001
|
+
class FlowArgInfo(NamedTuple):
    """Per-argument metadata for a transform flow function."""

    # Parameter name as declared in the function signature.
    name: str
    # The `T` extracted from the parameter's `DataSlice[T]` annotation.
    type_hint: Any
    # Encodes a Python value of that type into the engine representation.
    encoder: Callable[[Any], Any]
|
1005
|
+
|
1006
|
+
|
1007
|
+
class TransformFlow(Generic[T]):
    """
    A transient transformation flow that transforms in-memory data.

    Wraps a user function whose parameters are annotated `DataSlice[T]` and
    which returns a `DataSlice[T]`. The underlying engine flow is built
    lazily on first use and cached.
    """

    # The wrapped user function.
    _flow_fn: Callable[..., DataSlice[T]]
    # Generated (or user-supplied) flow name.
    _flow_name: str
    # Per-parameter metadata, in declaration order.
    _args_info: list[FlowArgInfo]

    # Guards lazy construction of `_lazy_flow_info`.
    _lazy_lock: asyncio.Lock
    _lazy_flow_info: TransformFlowInfo | None = None

    def __init__(
        self,
        flow_fn: Callable[..., DataSlice[T]],
        /,
        name: str | None = None,
    ):
        """
        Validate the function's signature and precompute per-argument
        encoders.

        Raises:
            ValueError: if a parameter cannot be passed by name, or lacks a
                `DataSlice[T]` annotation.
        """
        self._flow_fn = flow_fn
        self._flow_name = _transform_flow_name_builder.build_name(
            name, prefix="_transform_flow_"
        )
        self._lazy_lock = asyncio.Lock()

        sig = inspect.signature(flow_fn)
        args_info = []
        for param_name, param in sig.parameters.items():
            # Only named parameters are supported (no *args / **kwargs /
            # positional-only), since arguments are wired up by name.
            if param.kind not in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                raise ValueError(
                    f"Parameter `{param_name}` is not a parameter can be passed by name"
                )
            value_type_annotation: type | None = _get_data_slice_annotation_type(
                param.annotation
            )
            if value_type_annotation is None:
                raise ValueError(
                    f"Parameter `{param_name}` for {flow_fn} has no value type annotation. "
                    "Please use `cocoindex.DataSlice[T]` where T is the type of the value."
                )
            encoder = make_engine_value_encoder(
                analyze_type_info(value_type_annotation)
            )
            args_info.append(FlowArgInfo(param_name, value_type_annotation, encoder))
        self._args_info = args_info

    def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]:
        # Calling the wrapper directly just invokes the user function.
        return self._flow_fn(*args, **kwargs)

    @property
    def _flow_info(self) -> TransformFlowInfo:
        """Synchronous access to the lazily-built flow info."""
        if self._lazy_flow_info is not None:
            return self._lazy_flow_info
        return execution_context.run(self._flow_info_async())

    async def _flow_info_async(self) -> TransformFlowInfo:
        """Build (once) and return the flow info, double-checked under the lock."""
        if self._lazy_flow_info is not None:
            return self._lazy_flow_info
        async with self._lazy_lock:
            if self._lazy_flow_info is None:
                self._lazy_flow_info = await self._build_flow_info_async()
            return self._lazy_flow_info

    async def _build_flow_info_async(self) -> TransformFlowInfo:
        """
        Build the transient engine flow by tracing the user function over
        symbolic `DataSlice` inputs, and derive the result decoder.
        """
        flow_builder_state = _FlowBuilderState(self._flow_name)
        kwargs: dict[str, DataSlice[T]] = {}
        for arg_info in self._args_info:
            encoded_type = encode_enriched_type(arg_info.type_hint)
            if encoded_type is None:
                raise ValueError(f"Parameter `{arg_info.name}` has no type annotation")
            engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
                arg_info.name, encoded_type
            )
            kwargs[arg_info.name] = DataSlice(
                _DataSliceState(flow_builder_state, engine_ds)
            )

        # The user function may block; run it off the event loop.
        output = await asyncio.to_thread(lambda: self._flow_fn(**kwargs))
        output_data_slice = await _data_slice_state(output).engine_data_slice_async()

        flow_builder_state.engine_flow_builder.set_direct_output(output_data_slice)
        engine_flow = (
            await flow_builder_state.engine_flow_builder.build_transient_flow_async(
                execution_context.event_loop
            )
        )
        engine_return_type = output_data_slice.data_type().schema()
        python_return_type: type[T] | None = _get_data_slice_annotation_type(
            inspect.signature(self._flow_fn).return_annotation
        )
        result_decoder = make_engine_value_decoder(
            [], engine_return_type["type"], analyze_type_info(python_return_type)
        )

        return TransformFlowInfo(engine_flow, result_decoder)

    def __str__(self) -> str:
        return str(self._flow_info.engine_flow)

    def __repr__(self) -> str:
        return repr(self._flow_info.engine_flow)

    def internal_flow(self) -> _engine.TransientFlow:
        """
        Get the internal flow.
        """
        return self._flow_info.engine_flow

    def eval(self, *args: Any, **kwargs: Any) -> T:
        """
        Evaluate the transform flow.

        Synchronous wrapper around `eval_async()`.
        """
        return execution_context.run(self.eval_async(*args, **kwargs))

    async def eval_async(self, *args: Any, **kwargs: Any) -> T:
        """
        Evaluate the transform flow.

        Arguments may be passed positionally or by keyword, matched against
        the declared parameters in order.

        Raises:
            ValueError: if a declared parameter is provided neither
                positionally nor by keyword.
        """
        flow_info = await self._flow_info_async()
        params = []
        for i, arg_info in enumerate(self._args_info):
            # Bug fix: keyword lookup must key on the parameter name
            # (`arg_info.name`); the previous code referenced `arg`, which is
            # unbound on the first iteration when no positional args are
            # given, and stale otherwise.
            if i < len(args):
                arg = args[i]
            elif arg_info.name in kwargs:
                arg = kwargs[arg_info.name]
            else:
                raise ValueError(f"Parameter {arg_info.name} is not provided")
            params.append(arg_info.encoder(arg))
        engine_result = await flow_info.engine_flow.evaluate_async(params)
        return flow_info.result_decoder(engine_result)
|
1139
|
+
|
1140
|
+
|
1141
|
+
def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]]:
    """
    A decorator to wrap the transform function.
    """

    def _wrap(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]:
        wrapped = TransformFlow(fn)
        # Carry over the wrapped function's metadata (__name__, __doc__, ...).
        functools.update_wrapper(wrapped, fn)
        return wrapped

    return _wrap
|
1152
|
+
|
1153
|
+
|
1154
|
+
async def make_setup_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
    """
    Make a bundle to setup the given flows.
    """
    full_names: list[str] = []
    for fl in flow_iter:
        # Each flow must be built before its setup changes can be computed.
        await fl.internal_flow_async()
        full_names.append(fl.full_name)
    return SetupChangeBundle(_engine.make_setup_bundle(full_names))
|
1163
|
+
|
1164
|
+
|
1165
|
+
def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
    """
    Make a bundle to setup the given flows.

    Synchronous wrapper around `make_setup_bundle_async()`.
    """
    return execution_context.run(make_setup_bundle_async(flow_iter))
|
1170
|
+
|
1171
|
+
|
1172
|
+
async def make_drop_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
    """
    Make a bundle to drop the given flows.
    """
    full_names: list[str] = []
    for fl in flow_iter:
        # Each flow must be built before its drop changes can be computed.
        await fl.internal_flow_async()
        full_names.append(fl.full_name)
    return SetupChangeBundle(_engine.make_drop_bundle(full_names))
|
1181
|
+
|
1182
|
+
|
1183
|
+
def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
    """
    Make a bundle to drop the given flows.

    Synchronous wrapper around `make_drop_bundle_async()`.
    """
    return execution_context.run(make_drop_bundle_async(flow_iter))
|
1188
|
+
|
1189
|
+
|
1190
|
+
def setup_all_flows(report_to_stdout: bool = False) -> None:
    """
    Setup all flows registered in the current process.
    """
    # Snapshot the registry under the lock; apply the bundle outside it.
    with _flows_lock:
        all_flows = list(_flows.values())
    bundle = make_setup_bundle(all_flows)
    bundle.describe_and_apply(report_to_stdout=report_to_stdout)
|
1197
|
+
|
1198
|
+
|
1199
|
+
def drop_all_flows(report_to_stdout: bool = False) -> None:
    """
    Drop all flows registered in the current process.
    """
    # Snapshot the registry under the lock; apply the bundle outside it.
    with _flows_lock:
        all_flows = list(_flows.values())
    bundle = make_drop_bundle(all_flows)
    bundle.describe_and_apply(report_to_stdout=report_to_stdout)