vgi-python 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vgi/__init__.py +152 -0
- vgi/_duckdb.py +62 -0
- vgi/_storage_profile.py +132 -0
- vgi/_test_fixtures/__init__.py +20 -0
- vgi/_test_fixtures/accumulate/__init__.py +19 -0
- vgi/_test_fixtures/accumulate/worker.py +762 -0
- vgi/_test_fixtures/aggregate/__init__.py +62 -0
- vgi/_test_fixtures/aggregate/_common.py +21 -0
- vgi/_test_fixtures/aggregate/basic.py +232 -0
- vgi/_test_fixtures/aggregate/dynamic.py +409 -0
- vgi/_test_fixtures/aggregate/generic.py +86 -0
- vgi/_test_fixtures/aggregate/listagg.py +71 -0
- vgi/_test_fixtures/aggregate/percentile.py +107 -0
- vgi/_test_fixtures/aggregate/streaming.py +192 -0
- vgi/_test_fixtures/aggregate/varargs.py +75 -0
- vgi/_test_fixtures/aggregate/window.py +380 -0
- vgi/_test_fixtures/attach_options.py +308 -0
- vgi/_test_fixtures/bad_protocol.py +62 -0
- vgi/_test_fixtures/cancellable.py +336 -0
- vgi/_test_fixtures/catalog.py +813 -0
- vgi/_test_fixtures/http_server.py +394 -0
- vgi/_test_fixtures/nest_tensor.py +614 -0
- vgi/_test_fixtures/orchard_catalog.py +47 -0
- vgi/_test_fixtures/projection_repro/__init__.py +6 -0
- vgi/_test_fixtures/projection_repro/worker.py +454 -0
- vgi/_test_fixtures/scalar/__init__.py +116 -0
- vgi/_test_fixtures/scalar/_common.py +69 -0
- vgi/_test_fixtures/scalar/arithmetic.py +321 -0
- vgi/_test_fixtures/scalar/binary.py +120 -0
- vgi/_test_fixtures/scalar/formatting.py +176 -0
- vgi/_test_fixtures/scalar/geo.py +300 -0
- vgi/_test_fixtures/scalar/null_handling.py +107 -0
- vgi/_test_fixtures/scalar/random_demo.py +171 -0
- vgi/_test_fixtures/scalar/settings_secrets.py +102 -0
- vgi/_test_fixtures/scalar/type_info.py +219 -0
- vgi/_test_fixtures/schema_reconcile/__init__.py +29 -0
- vgi/_test_fixtures/schema_reconcile/worker.py +653 -0
- vgi/_test_fixtures/simple_writable.py +793 -0
- vgi/_test_fixtures/table/__init__.py +221 -0
- vgi/_test_fixtures/table/_common.py +162 -0
- vgi/_test_fixtures/table/batch_index.py +283 -0
- vgi/_test_fixtures/table/batch_index_broken.py +200 -0
- vgi/_test_fixtures/table/catalog_scans.py +162 -0
- vgi/_test_fixtures/table/filters.py +1005 -0
- vgi/_test_fixtures/table/late_materialization.py +249 -0
- vgi/_test_fixtures/table/make_series.py +273 -0
- vgi/_test_fixtures/table/misc.py +499 -0
- vgi/_test_fixtures/table/order_modes.py +164 -0
- vgi/_test_fixtures/table/pairs.py +437 -0
- vgi/_test_fixtures/table/partition_columns.py +472 -0
- vgi/_test_fixtures/table/partition_columns_broken.py +304 -0
- vgi/_test_fixtures/table/profiling_example.py +195 -0
- vgi/_test_fixtures/table/required_filters.py +234 -0
- vgi/_test_fixtures/table/sequence.py +710 -0
- vgi/_test_fixtures/table/settings.py +426 -0
- vgi/_test_fixtures/table/transaction_storage.py +162 -0
- vgi/_test_fixtures/table/tt_pushdown.py +191 -0
- vgi/_test_fixtures/table/versioned.py +230 -0
- vgi/_test_fixtures/table_in_out.py +1392 -0
- vgi/_test_fixtures/versioned.py +155 -0
- vgi/_test_fixtures/versioned_tables.py +595 -0
- vgi/_test_fixtures/worker.py +1631 -0
- vgi/_test_fixtures/writable/__init__.py +8 -0
- vgi/_test_fixtures/writable/generic.py +236 -0
- vgi/_test_fixtures/writable/table.py +149 -0
- vgi/_test_fixtures/writable/worker.py +1148 -0
- vgi/aggregate_function.py +607 -0
- vgi/argument_spec.py +472 -0
- vgi/arguments.py +1747 -0
- vgi/auth.py +55 -0
- vgi/catalog/__init__.py +88 -0
- vgi/catalog/attach_option.py +206 -0
- vgi/catalog/catalog_interface.py +2767 -0
- vgi/catalog/descriptors.py +870 -0
- vgi/catalog/duckdb_statistics.py +377 -0
- vgi/catalog/secret_type.py +96 -0
- vgi/catalog/setting.py +253 -0
- vgi/catalog/storage.py +372 -0
- vgi/client/__init__.py +67 -0
- vgi/client/catalog_mixin.py +1251 -0
- vgi/client/cli.py +582 -0
- vgi/client/cli_catalog.py +182 -0
- vgi/client/cli_schema.py +270 -0
- vgi/client/cli_table.py +907 -0
- vgi/client/cli_transaction.py +97 -0
- vgi/client/cli_utils.py +441 -0
- vgi/client/cli_view.py +303 -0
- vgi/client/client.py +2183 -0
- vgi/exceptions.py +205 -0
- vgi/function.py +245 -0
- vgi/function_storage.py +1636 -0
- vgi/function_storage_azure_sql.py +922 -0
- vgi/function_storage_cf_do.py +740 -0
- vgi/http/__init__.py +25 -0
- vgi/http/demo_storage.py +212 -0
- vgi/http/worker_page.py +1252 -0
- vgi/invocation.py +154 -0
- vgi/logging_config.py +93 -0
- vgi/meta_worker.py +661 -0
- vgi/metadata.py +1403 -0
- vgi/otel.py +406 -0
- vgi/protocol.py +2418 -0
- vgi/protocol_version.txt +1 -0
- vgi/py.typed +0 -0
- vgi/scalar_function.py +1211 -0
- vgi/schema_utils.py +234 -0
- vgi/secret_protocol.py +124 -0
- vgi/secret_service.py +238 -0
- vgi/serve.py +769 -0
- vgi/table_buffering_function.py +443 -0
- vgi/table_filter_pushdown.py +1528 -0
- vgi/table_function.py +1130 -0
- vgi/table_in_out_function.py +383 -0
- vgi/transactor/__init__.py +24 -0
- vgi/transactor/_duckdb_compat.py +27 -0
- vgi/transactor/client.py +137 -0
- vgi/transactor/protocol.py +149 -0
- vgi/transactor/server.py +740 -0
- vgi/worker.py +4761 -0
- vgi_python-0.8.0.dist-info/METADATA +735 -0
- vgi_python-0.8.0.dist-info/RECORD +124 -0
- vgi_python-0.8.0.dist-info/WHEEL +4 -0
- vgi_python-0.8.0.dist-info/entry_points.txt +5 -0
- vgi_python-0.8.0.dist-info/licenses/LICENSE +134 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Aggregate-function fixtures.
|
|
4
|
+
|
|
5
|
+
Originally a single 1,179-line module; split into cohesive sub-modules and
|
|
6
|
+
re-exported here so existing import sites (worker.py, tests) keep working
|
|
7
|
+
unchanged.
|
|
8
|
+
|
|
9
|
+
* :mod:`._common` — shared SumState, ListAggState
|
|
10
|
+
* :mod:`.basic` — count, sum, avg, weighted_sum
|
|
11
|
+
* :mod:`.listagg` — list_agg (order-dependent string concatenation)
|
|
12
|
+
* :mod:`.percentile` — percentile (sorted-quantile demo)
|
|
13
|
+
* :mod:`.generic` — generic_sum (any-type aggregate)
|
|
14
|
+
* :mod:`.varargs` — sum_all (varargs over numeric columns)
|
|
15
|
+
* :mod:`.dynamic` — dynamic_aggregate, dynamic_ml_aggregate
|
|
16
|
+
(gated on VGI_WORKER_SUPPORTS_DYNAMIC_CODE)
|
|
17
|
+
* :mod:`.window` — window_sum, window_median, window_listagg
|
|
18
|
+
* :mod:`.streaming` — streaming_sum (streaming-partitioned protocol)
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from vgi._test_fixtures.aggregate._common import ListAggState, SumState
|
|
22
|
+
from vgi._test_fixtures.aggregate.basic import (
|
|
23
|
+
AvgFunction,
|
|
24
|
+
CountFunction,
|
|
25
|
+
SumFunction,
|
|
26
|
+
WeightedSumFunction,
|
|
27
|
+
)
|
|
28
|
+
from vgi._test_fixtures.aggregate.dynamic import (
|
|
29
|
+
DynamicAggregateFunction,
|
|
30
|
+
DynamicMLAggregateFunction,
|
|
31
|
+
)
|
|
32
|
+
from vgi._test_fixtures.aggregate.generic import GenericSumFunction
|
|
33
|
+
from vgi._test_fixtures.aggregate.listagg import ListAggFunction
|
|
34
|
+
from vgi._test_fixtures.aggregate.percentile import PercentileFunction
|
|
35
|
+
from vgi._test_fixtures.aggregate.streaming import StreamingSumFunction
|
|
36
|
+
from vgi._test_fixtures.aggregate.varargs import SumAllFunction
|
|
37
|
+
from vgi._test_fixtures.aggregate.window import (
|
|
38
|
+
WindowListAggFunction,
|
|
39
|
+
WindowMedianFunction,
|
|
40
|
+
WindowSumBatchFunction,
|
|
41
|
+
WindowSumFunction,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
"AvgFunction",
|
|
46
|
+
"CountFunction",
|
|
47
|
+
"DynamicAggregateFunction",
|
|
48
|
+
"DynamicMLAggregateFunction",
|
|
49
|
+
"GenericSumFunction",
|
|
50
|
+
"ListAggFunction",
|
|
51
|
+
"ListAggState",
|
|
52
|
+
"PercentileFunction",
|
|
53
|
+
"StreamingSumFunction",
|
|
54
|
+
"SumAllFunction",
|
|
55
|
+
"SumFunction",
|
|
56
|
+
"SumState",
|
|
57
|
+
"WeightedSumFunction",
|
|
58
|
+
"WindowListAggFunction",
|
|
59
|
+
"WindowMedianFunction",
|
|
60
|
+
"WindowSumBatchFunction",
|
|
61
|
+
"WindowSumFunction",
|
|
62
|
+
]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Shared aggregate state classes used across multiple submodules."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Annotated
|
|
9
|
+
|
|
10
|
+
import pyarrow as pa
|
|
11
|
+
from vgi_rpc import ArrowSerializableDataclass, ArrowType
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(kw_only=True)
|
|
15
|
+
class SumState(ArrowSerializableDataclass):
|
|
16
|
+
total: Annotated[int, ArrowType(pa.int64())] = 0
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(kw_only=True)
|
|
20
|
+
class ListAggState(ArrowSerializableDataclass):
|
|
21
|
+
values: Annotated[str, ArrowType(pa.string())] = ""
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# Copyright 2025, 2026 Query Farm LLC - https://query.farm
|
|
2
|
+
|
|
3
|
+
"""Basic aggregate fixtures: count, sum, avg, weighted_sum."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Annotated
|
|
9
|
+
|
|
10
|
+
import pyarrow as pa
|
|
11
|
+
from vgi_rpc import ArrowSerializableDataclass, ArrowType
|
|
12
|
+
|
|
13
|
+
from vgi._test_fixtures.aggregate._common import SumState
|
|
14
|
+
from vgi.aggregate_function import AggregateFunction
|
|
15
|
+
from vgi.arguments import Param, Returns
|
|
16
|
+
from vgi.metadata import DistinctDependence, NullHandling, OrderDependence
|
|
17
|
+
from vgi.table_function import ProcessParams
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(kw_only=True)
|
|
21
|
+
class CountState(ArrowSerializableDataclass):
|
|
22
|
+
count: Annotated[int, ArrowType(pa.int64())] = 0
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(kw_only=True)
|
|
26
|
+
class AvgState(ArrowSerializableDataclass):
|
|
27
|
+
total: Annotated[float, ArrowType(pa.float64())] = 0.0
|
|
28
|
+
count: Annotated[int, ArrowType(pa.int64())] = 0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(kw_only=True)
|
|
32
|
+
class WeightedSumState(ArrowSerializableDataclass):
|
|
33
|
+
total: Annotated[float, ArrowType(pa.float64())] = 0.0
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class CountFunction(AggregateFunction[CountState]):
|
|
37
|
+
"""Count aggregate — nullary (no input columns).
|
|
38
|
+
|
|
39
|
+
SQL: ``SELECT vgi_count() FROM t``
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
class Meta:
|
|
43
|
+
name = "vgi_count"
|
|
44
|
+
description = "Count rows"
|
|
45
|
+
null_handling = NullHandling.SPECIAL
|
|
46
|
+
order_dependent = OrderDependence.NOT_ORDER_DEPENDENT
|
|
47
|
+
distinct_dependent = DistinctDependence.NOT_DISTINCT_DEPENDENT
|
|
48
|
+
|
|
49
|
+
@classmethod
|
|
50
|
+
def initial_state(cls, params: ProcessParams[None]) -> CountState:
|
|
51
|
+
return CountState()
|
|
52
|
+
|
|
53
|
+
@classmethod
|
|
54
|
+
def update(
|
|
55
|
+
cls,
|
|
56
|
+
states: dict[int, CountState],
|
|
57
|
+
group_ids: pa.Int64Array,
|
|
58
|
+
) -> None:
|
|
59
|
+
table = pa.table({"gid": group_ids})
|
|
60
|
+
grouped = table.group_by("gid").aggregate([("gid", "count")])
|
|
61
|
+
for i in range(grouped.num_rows):
|
|
62
|
+
gid: int = grouped.column("gid")[i].as_py()
|
|
63
|
+
cnt: int = grouped.column("gid_count")[i].as_py()
|
|
64
|
+
states[gid] = CountState(count=states[gid].count + cnt)
|
|
65
|
+
|
|
66
|
+
@classmethod
|
|
67
|
+
def combine(cls, source: CountState, target: CountState, params: ProcessParams[None]) -> CountState:
|
|
68
|
+
return CountState(count=source.count + target.count)
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def finalize(
|
|
72
|
+
cls,
|
|
73
|
+
group_ids: pa.Int64Array,
|
|
74
|
+
states: dict[int, CountState],
|
|
75
|
+
params: ProcessParams[None],
|
|
76
|
+
) -> Annotated[pa.RecordBatch, Returns(pa.int64())]:
|
|
77
|
+
results = [s.count if (s := states[gid.as_py()]) is not None else 0 for gid in group_ids]
|
|
78
|
+
return pa.record_batch({"result": pa.array(results, type=pa.int64())})
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SumFunction(AggregateFunction[SumState]):
|
|
82
|
+
"""Sum aggregate — single int64 input.
|
|
83
|
+
|
|
84
|
+
SQL: ``SELECT vgi_sum(value) FROM t GROUP BY category``
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
class Meta:
|
|
88
|
+
name = "vgi_sum"
|
|
89
|
+
description = "Sum integer values"
|
|
90
|
+
null_handling = NullHandling.DEFAULT
|
|
91
|
+
order_dependent = OrderDependence.NOT_ORDER_DEPENDENT
|
|
92
|
+
distinct_dependent = DistinctDependence.NOT_DISTINCT_DEPENDENT
|
|
93
|
+
|
|
94
|
+
@classmethod
|
|
95
|
+
def initial_state(cls, params: ProcessParams[None]) -> SumState:
|
|
96
|
+
return SumState()
|
|
97
|
+
|
|
98
|
+
@classmethod
|
|
99
|
+
def update(
|
|
100
|
+
cls,
|
|
101
|
+
states: dict[int, SumState],
|
|
102
|
+
group_ids: pa.Int64Array,
|
|
103
|
+
value: Annotated[pa.Int64Array, Param(doc="Column to sum")],
|
|
104
|
+
) -> None:
|
|
105
|
+
table = pa.table({"gid": group_ids, "value": value})
|
|
106
|
+
grouped = table.group_by("gid").aggregate([("value", "sum")])
|
|
107
|
+
for i in range(grouped.num_rows):
|
|
108
|
+
gid: int = grouped.column("gid")[i].as_py()
|
|
109
|
+
val = grouped.column("value_sum")[i].as_py()
|
|
110
|
+
if val is not None:
|
|
111
|
+
states[gid] = SumState(total=states[gid].total + val)
|
|
112
|
+
|
|
113
|
+
@classmethod
|
|
114
|
+
def combine(cls, source: SumState, target: SumState, params: ProcessParams[None]) -> SumState:
|
|
115
|
+
return SumState(total=source.total + target.total)
|
|
116
|
+
|
|
117
|
+
@classmethod
|
|
118
|
+
def finalize(
|
|
119
|
+
cls,
|
|
120
|
+
group_ids: pa.Int64Array,
|
|
121
|
+
states: dict[int, SumState],
|
|
122
|
+
params: ProcessParams[None],
|
|
123
|
+
) -> Annotated[pa.RecordBatch, Returns(pa.int64())]:
|
|
124
|
+
results = [s.total if (s := states[gid.as_py()]) is not None else None for gid in group_ids]
|
|
125
|
+
return pa.record_batch({"result": pa.array(results, type=pa.int64())})
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class AvgFunction(AggregateFunction[AvgState]):
|
|
129
|
+
"""Average aggregate — two-field state (sum + count).
|
|
130
|
+
|
|
131
|
+
SQL: ``SELECT vgi_avg(value) FROM t GROUP BY category``
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
class Meta:
|
|
135
|
+
name = "vgi_avg"
|
|
136
|
+
description = "Average of integer values"
|
|
137
|
+
null_handling = NullHandling.DEFAULT
|
|
138
|
+
order_dependent = OrderDependence.NOT_ORDER_DEPENDENT
|
|
139
|
+
distinct_dependent = DistinctDependence.NOT_DISTINCT_DEPENDENT
|
|
140
|
+
|
|
141
|
+
@classmethod
|
|
142
|
+
def initial_state(cls, params: ProcessParams[None]) -> AvgState:
|
|
143
|
+
return AvgState()
|
|
144
|
+
|
|
145
|
+
@classmethod
|
|
146
|
+
def update(
|
|
147
|
+
cls,
|
|
148
|
+
states: dict[int, AvgState],
|
|
149
|
+
group_ids: pa.Int64Array,
|
|
150
|
+
value: Annotated[pa.Int64Array, Param(doc="Column to average")],
|
|
151
|
+
) -> None:
|
|
152
|
+
table = pa.table({"gid": group_ids, "value": value})
|
|
153
|
+
grouped = table.group_by("gid").aggregate([("value", "sum"), ("value", "count")])
|
|
154
|
+
for i in range(grouped.num_rows):
|
|
155
|
+
gid: int = grouped.column("gid")[i].as_py()
|
|
156
|
+
val_sum = grouped.column("value_sum")[i].as_py()
|
|
157
|
+
val_count: int = grouped.column("value_count")[i].as_py()
|
|
158
|
+
s = states[gid]
|
|
159
|
+
states[gid] = AvgState(
|
|
160
|
+
total=s.total + (val_sum if val_sum is not None else 0.0),
|
|
161
|
+
count=s.count + val_count,
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
@classmethod
|
|
165
|
+
def combine(cls, source: AvgState, target: AvgState, params: ProcessParams[None]) -> AvgState:
|
|
166
|
+
return AvgState(total=source.total + target.total, count=source.count + target.count)
|
|
167
|
+
|
|
168
|
+
@classmethod
|
|
169
|
+
def finalize(
|
|
170
|
+
cls,
|
|
171
|
+
group_ids: pa.Int64Array,
|
|
172
|
+
states: dict[int, AvgState],
|
|
173
|
+
params: ProcessParams[None],
|
|
174
|
+
) -> Annotated[pa.RecordBatch, Returns(pa.float64())]:
|
|
175
|
+
results = []
|
|
176
|
+
for gid in group_ids:
|
|
177
|
+
s = states[gid.as_py()]
|
|
178
|
+
results.append(s.total / s.count if s is not None and s.count > 0 else None)
|
|
179
|
+
return pa.record_batch({"result": pa.array(results, type=pa.float64())})
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class WeightedSumFunction(AggregateFunction[WeightedSumState]):
|
|
183
|
+
"""Weighted sum aggregate — multi-input (value + weight).
|
|
184
|
+
|
|
185
|
+
SQL: ``SELECT vgi_weighted_sum(value, weight) FROM t GROUP BY category``
|
|
186
|
+
"""
|
|
187
|
+
|
|
188
|
+
class Meta:
|
|
189
|
+
name = "vgi_weighted_sum"
|
|
190
|
+
description = "Weighted sum of values"
|
|
191
|
+
null_handling = NullHandling.DEFAULT
|
|
192
|
+
order_dependent = OrderDependence.NOT_ORDER_DEPENDENT
|
|
193
|
+
distinct_dependent = DistinctDependence.NOT_DISTINCT_DEPENDENT
|
|
194
|
+
|
|
195
|
+
@classmethod
|
|
196
|
+
def initial_state(cls, params: ProcessParams[None]) -> WeightedSumState:
|
|
197
|
+
return WeightedSumState()
|
|
198
|
+
|
|
199
|
+
@classmethod
|
|
200
|
+
def update(
|
|
201
|
+
cls,
|
|
202
|
+
states: dict[int, WeightedSumState],
|
|
203
|
+
group_ids: pa.Int64Array,
|
|
204
|
+
value: Annotated[pa.DoubleArray, Param(doc="Values to sum")],
|
|
205
|
+
weight: Annotated[pa.DoubleArray, Param(doc="Weights")],
|
|
206
|
+
) -> None:
|
|
207
|
+
import pyarrow.compute as pc
|
|
208
|
+
|
|
209
|
+
products = pc.multiply(value, weight)
|
|
210
|
+
table = pa.table({"gid": group_ids, "product": products})
|
|
211
|
+
grouped = table.group_by("gid").aggregate([("product", "sum")])
|
|
212
|
+
for i in range(grouped.num_rows):
|
|
213
|
+
gid: int = grouped.column("gid")[i].as_py()
|
|
214
|
+
val = grouped.column("product_sum")[i].as_py()
|
|
215
|
+
if val is not None:
|
|
216
|
+
states[gid] = WeightedSumState(total=states[gid].total + val)
|
|
217
|
+
|
|
218
|
+
@classmethod
|
|
219
|
+
def combine(
|
|
220
|
+
cls, source: WeightedSumState, target: WeightedSumState, params: ProcessParams[None]
|
|
221
|
+
) -> WeightedSumState:
|
|
222
|
+
return WeightedSumState(total=source.total + target.total)
|
|
223
|
+
|
|
224
|
+
@classmethod
|
|
225
|
+
def finalize(
|
|
226
|
+
cls,
|
|
227
|
+
group_ids: pa.Int64Array,
|
|
228
|
+
states: dict[int, WeightedSumState],
|
|
229
|
+
params: ProcessParams[None],
|
|
230
|
+
) -> Annotated[pa.RecordBatch, Returns(pa.float64())]:
|
|
231
|
+
results = [s.total if (s := states[gid.as_py()]) is not None else None for gid in group_ids]
|
|
232
|
+
return pa.record_batch({"result": pa.array(results, type=pa.float64())})
|