palimpzest 0.8.7__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +13 -4
- palimpzest/core/data/dataset.py +75 -5
- palimpzest/core/elements/groupbysig.py +5 -1
- palimpzest/core/elements/records.py +16 -7
- palimpzest/core/lib/schemas.py +26 -3
- palimpzest/core/models.py +4 -4
- palimpzest/prompts/aggregate_prompts.py +99 -0
- palimpzest/prompts/prompt_factory.py +162 -75
- palimpzest/prompts/utils.py +38 -1
- palimpzest/prompts/validator.py +24 -24
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +8 -8
- palimpzest/query/execution/mab_execution_strategy.py +30 -11
- palimpzest/query/execution/parallel_execution_strategy.py +31 -7
- palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
- palimpzest/query/generators/generators.py +9 -7
- palimpzest/query/operators/__init__.py +10 -6
- palimpzest/query/operators/aggregate.py +394 -10
- palimpzest/query/operators/convert.py +1 -1
- palimpzest/query/operators/join.py +279 -23
- palimpzest/query/operators/logical.py +36 -11
- palimpzest/query/operators/mixture_of_agents.py +3 -1
- palimpzest/query/operators/physical.py +5 -2
- palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
- palimpzest/query/optimizer/__init__.py +11 -3
- palimpzest/query/optimizer/cost_model.py +5 -5
- palimpzest/query/optimizer/optimizer.py +3 -2
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/rules.py +73 -13
- palimpzest/query/optimizer/tasks.py +4 -4
- palimpzest/utils/progress.py +19 -17
- palimpzest/validator/validator.py +7 -7
- {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
- {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/RECORD +37 -36
- {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import time
|
|
4
|
-
|
|
5
|
-
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from palimpzest.constants import (
|
|
8
|
+
MODEL_CARDS,
|
|
9
|
+
NAIVE_EST_NUM_GROUPS,
|
|
10
|
+
NAIVE_EST_NUM_INPUT_TOKENS,
|
|
11
|
+
NAIVE_EST_NUM_OUTPUT_TOKENS,
|
|
12
|
+
AggFunc,
|
|
13
|
+
Model,
|
|
14
|
+
PromptStrategy,
|
|
15
|
+
)
|
|
6
16
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
7
17
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
8
|
-
from palimpzest.core.lib.schemas import Average, Count
|
|
18
|
+
from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
|
|
9
19
|
from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
|
|
20
|
+
from palimpzest.query.generators.generators import Generator
|
|
10
21
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
11
22
|
|
|
12
23
|
|
|
@@ -58,6 +69,16 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
58
69
|
return 0
|
|
59
70
|
elif func.lower() == "average":
|
|
60
71
|
return (0, 0)
|
|
72
|
+
elif func.lower() == "sum":
|
|
73
|
+
return 0
|
|
74
|
+
elif func.lower() == "min":
|
|
75
|
+
return float("inf")
|
|
76
|
+
elif func.lower() == "max":
|
|
77
|
+
return float("-inf")
|
|
78
|
+
elif func.lower() == "list":
|
|
79
|
+
return []
|
|
80
|
+
elif func.lower() == "set":
|
|
81
|
+
return set()
|
|
61
82
|
else:
|
|
62
83
|
raise Exception("Unknown agg function " + func)
|
|
63
84
|
|
|
@@ -66,16 +87,34 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
66
87
|
if func.lower() == "count":
|
|
67
88
|
return state + 1
|
|
68
89
|
elif func.lower() == "average":
|
|
69
|
-
|
|
90
|
+
sum_, cnt = state
|
|
91
|
+
if val is None:
|
|
92
|
+
return (sum_, cnt)
|
|
93
|
+
return (sum_ + val, cnt + 1)
|
|
94
|
+
elif func.lower() == "sum":
|
|
70
95
|
if val is None:
|
|
71
|
-
return
|
|
72
|
-
return
|
|
96
|
+
return state
|
|
97
|
+
return state + sum(val) if isinstance(val, list) else state + val
|
|
98
|
+
elif func.lower() == "min":
|
|
99
|
+
if val is None:
|
|
100
|
+
return state
|
|
101
|
+
return min(state, min(val) if isinstance(val, list) else val)
|
|
102
|
+
elif func.lower() == "max":
|
|
103
|
+
if val is None:
|
|
104
|
+
return state
|
|
105
|
+
return max(state, max(val) if isinstance(val, list) else val)
|
|
106
|
+
elif func.lower() == "list":
|
|
107
|
+
state.append(val)
|
|
108
|
+
return state
|
|
109
|
+
elif func.lower() == "set":
|
|
110
|
+
state.add(val)
|
|
111
|
+
return state
|
|
73
112
|
else:
|
|
74
113
|
raise Exception("Unknown agg function " + func)
|
|
75
114
|
|
|
76
115
|
@staticmethod
|
|
77
116
|
def agg_final(func, state):
|
|
78
|
-
if func.lower()
|
|
117
|
+
if func.lower() in ["count", "sum", "min", "max", "list", "set"]:
|
|
79
118
|
return state
|
|
80
119
|
elif func.lower() == "average":
|
|
81
120
|
sum, cnt = state
|
|
@@ -156,12 +195,17 @@ class AverageAggregateOp(AggregateOp):
|
|
|
156
195
|
|
|
157
196
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
158
197
|
# enforce that output schema is correct
|
|
159
|
-
assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
|
|
198
|
+
assert kwargs["output_schema"].model_fields.keys() == Average.model_fields.keys(), "AverageAggregateOp requires output_schema to be Average"
|
|
160
199
|
|
|
161
200
|
# enforce that input schema is a single numeric field
|
|
162
201
|
input_field_types = list(kwargs["input_schema"].model_fields.values())
|
|
163
202
|
assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
|
|
164
|
-
numeric_field_types = [
|
|
203
|
+
numeric_field_types = [
|
|
204
|
+
bool, int, float, int | float,
|
|
205
|
+
bool | None, int | None, float | None, int | float | None,
|
|
206
|
+
bool | Any, int | Any, float | Any, int | float | Any,
|
|
207
|
+
bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
|
|
208
|
+
]
|
|
165
209
|
is_numeric = input_field_types[0].annotation in numeric_field_types
|
|
166
210
|
assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
|
|
167
211
|
|
|
@@ -225,12 +269,88 @@ class AverageAggregateOp(AggregateOp):
|
|
|
225
269
|
return DataRecordSet([dr], [record_op_stats])
|
|
226
270
|
|
|
227
271
|
|
|
272
|
+
class SumAggregateOp(AggregateOp):
    """Physical operator which sums a single numeric input field across all input records.

    Exactly one output record (with the ``Sum`` schema) is produced; its parents are all of the
    input candidates.
    """

    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Sum.model_fields.keys(), "SumAggregateOp requires output_schema to be Sum"

        # enforce that input schema is a single numeric field
        input_field_types = list(kwargs["input_schema"].model_fields.values())
        assert len(input_field_types) == 1, "SumAggregateOp requires input_schema to have exactly one field"
        numeric_field_types = [
            bool, int, float, int | float,
            bool | None, int | None, float | None, int | float | None,
            bool | Any, int | Any, float | Any, int | float | Any,
            bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
        ]
        is_numeric = input_field_types[0].annotation in numeric_field_types
        assert is_numeric, f"SumAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params extended with the (stringified) aggregation function."""
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        """Return the parent's op params extended with the aggregation function."""
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Sum the (single) field of every candidate and return one aggregate record."""
        start_time = time.time()

        # NOTE: we currently do not guarantee that input values conform to their specified type;
        # as a result, we simply omit any values which do not parse to a float from the sum
        # NOTE: right now we perform a check in the constructor which enforces that the input_schema
        # has a single field which is numeric in nature; in the future we may want to have a
        # cleaner way of computing the value (rather than `float(list(candidate...))` below)
        summation = 0
        for candidate in candidates:
            with contextlib.suppress(Exception):
                summation += float(list(candidate.to_dict().values())[0])
        data_item = Sum(sum=summation)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
|
|
346
|
+
|
|
347
|
+
|
|
228
348
|
class CountAggregateOp(AggregateOp):
|
|
229
349
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
230
350
|
|
|
231
351
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
232
352
|
# enforce that output schema is correct
|
|
233
|
-
assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
|
|
353
|
+
assert kwargs["output_schema"].model_fields.keys() == Count.model_fields.keys(), "CountAggregateOp requires output_schema to be Count"
|
|
234
354
|
|
|
235
355
|
# call parent constructor
|
|
236
356
|
super().__init__(*args, **kwargs)
|
|
@@ -280,3 +400,267 @@ class CountAggregateOp(AggregateOp):
|
|
|
280
400
|
)
|
|
281
401
|
|
|
282
402
|
return DataRecordSet([dr], [record_op_stats])
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
class MinAggregateOp(AggregateOp):
    """Physical operator which computes the minimum of a single input field across all records.

    Exactly one output record (with the ``Min`` schema) is produced; its parents are all of the
    input candidates. If no candidate value parses to a float, the minimum is reported as None.
    """

    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Min.model_fields.keys(), "MinAggregateOp requires output_schema to be Min"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params extended with the (stringified) aggregation function."""
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        """Return the parent's op params extended with the aggregation function."""
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Take the minimum of every candidate's (single) field and return one aggregate record."""
        start_time = time.time()

        # NOTE: the running minimum lives in `min_value`, not a local named `min`; binding `min`
        # anywhere in this function would shadow the builtin for the whole body, making every
        # `min(...)` call raise a TypeError that the suppression below silently swallows.
        # `None` means "no parseable value seen yet" (so a genuine +inf input is still reported).
        min_value = None
        for candidate in candidates:
            # values which do not parse to a float are omitted from the minimum
            with contextlib.suppress(Exception):
                value = float(list(candidate.to_dict().values())[0])
                min_value = value if min_value is None else min(min_value, value)

        # create new DataRecord
        data_item = Min(min=min_value)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class MaxAggregateOp(AggregateOp):
    """Physical operator which computes the maximum of a single input field across all records.

    Exactly one output record (with the ``Max`` schema) is produced; its parents are all of the
    input candidates. If no candidate value parses to a float, the maximum is reported as None.
    """

    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Max.model_fields.keys(), "MaxAggregateOp requires output_schema to be Max"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params extended with the (stringified) aggregation function."""
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        """Return the parent's op params extended with the aggregation function."""
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Take the maximum of every candidate's (single) field and return one aggregate record."""
        start_time = time.time()

        # NOTE: the running maximum lives in `max_value`, not a local named `max`; binding `max`
        # anywhere in this function would shadow the builtin for the whole body, making every
        # `max(...)` call raise a TypeError that the suppression below silently swallows.
        # `None` means "no parseable value seen yet" (so a genuine -inf input is still reported).
        max_value = None
        for candidate in candidates:
            # values which do not parse to a float are omitted from the maximum
            with contextlib.suppress(Exception):
                value = float(list(candidate.to_dict().values())[0])
                max_value = value if max_value is None else max(max_value, value)

        # create new DataRecord
        data_item = Max(max=max_value)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
class SemanticAggregate(AggregateOp):
    """Physical operator which aggregates all input records with an LLM.

    The natural-language aggregation instruction (``agg_str``) is passed to a ``Generator``
    together with every candidate record; the generated field values populate exactly one
    output record whose parents are all of the candidates.
    """

    def __init__(self, agg_str: str, model: Model, prompt_strategy: PromptStrategy = PromptStrategy.AGG, reasoning_effort: str | None = None, *args, **kwargs):
        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_str = agg_str
        self.model = model
        self.prompt_strategy = prompt_strategy
        self.reasoning_effort = reasoning_effort
        # NOTE: the generator is only constructed when a model is provided; `__call__` requires it
        if model is not None:
            self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base)

    def __str__(self):
        op = super().__str__()
        op += f" Prompt Strategy: {self.prompt_strategy}\n"
        op += f" Reasoning Effort: {self.reasoning_effort}\n"
        op += f" Agg: {str(self.agg_str)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params extended with this operator's (serializable) settings."""
        id_params = super().get_id_params()
        id_params = {
            "agg_str": self.agg_str,
            "model": None if self.model is None else self.model.value,
            "prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
            "reasoning_effort": self.reasoning_effort,
            **id_params,
        }

        return id_params

    def get_op_params(self):
        """Return the parent's op params extended with this operator's constructor arguments."""
        op_params = super().get_op_params()
        op_params = {
            "agg_str": self.agg_str,
            "model": self.model,
            "prompt_strategy": self.prompt_strategy,
            "reasoning_effort": self.reasoning_effort,
            **op_params,
        }

        return op_params

    def get_model_name(self) -> str | None:
        """Return the model's name, or None when this operator has no model (as get_id_params allows)."""
        return None if self.model is None else self.model.value

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Compute naive cost estimates for the semantic aggregation. These estimates assume a single
        LLM call whose input is every record produced by the source operator (so the estimated
        input tokens scale with the source cardinality) and which emits a single output record.
        """
        # estimate number of input and output tokens from source
        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS * source_op_cost_estimates.cardinality
        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS

        # get est. of conversion time per record from model card;
        model_name = self.model.value
        model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens

        # get est. of conversion cost (in USD) per record from model card
        usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
        if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
            usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]

        model_conversion_usd_per_record = (
            usd_per_input_token * est_num_input_tokens
            + MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
        )

        # estimate quality of output based on the strength of the model being used
        quality = (MODEL_CARDS[model_name]["overall"] / 100.0)

        return OperatorCostEstimates(
            cardinality=1.0,
            time_per_record=model_conversion_time_per_record,
            cost_per_record=model_conversion_usd_per_record,
            quality=quality,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Aggregate all candidates with one LLM query and return a single output record."""
        start_time = time.time()

        # with no candidates there is nothing to aggregate
        if len(candidates) == 0:
            return DataRecordSet([], [])

        # get the set of input fields to use for the operation
        input_fields = self.get_input_fields()

        # get the set of output fields to use for the operation
        fields_to_generate = self.get_fields_to_generate(candidates[0])
        fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}

        # construct kwargs for generation
        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema, "agg_instruction": self.agg_str}

        # generate outputs for all fields in a single query
        field_answers, _, generation_stats, _ = self.generator(candidates, fields, **gen_kwargs)
        assert all(field in field_answers for field in fields), "Not all fields were generated!"

        # construct data record for the output from EVERY generated field (not just the first);
        # each answer is a list whose first element is the aggregate value, and a missing/empty
        # answer (e.g. a failed generation) yields None instead of crashing on `[0]`
        answer = {}
        for field in fields:
            field_values = field_answers[field]
            answer[field] = field_values[0] if field_values else None
        data_item = self.output_schema(**answer)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=generation_stats.cost_per_record,
            model_name=self.get_model_name(),
            answer=answer,
            input_fields=input_fields,
            generated_fields=fields_to_generate,
            total_input_tokens=generation_stats.total_input_tokens,
            total_output_tokens=generation_stats.total_output_tokens,
            total_input_cost=generation_stats.total_input_cost,
            total_output_cost=generation_stats.total_output_cost,
            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            total_llm_calls=generation_stats.total_llm_calls,
            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
            image_operation=self.is_image_op(),
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
|
|
@@ -320,7 +320,7 @@ class LLMConvert(ConvertOp):
|
|
|
320
320
|
est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
|
|
321
321
|
|
|
322
322
|
# get est. of conversion time per record from model card;
|
|
323
|
-
model_name = self.model.value
|
|
323
|
+
model_name = self.model.value
|
|
324
324
|
model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
|
|
325
325
|
|
|
326
326
|
# get est. of conversion cost (in USD) per record from model card
|