palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/operators/search.py (new file)
@@ -0,0 +1,524 @@
+import functools
+import inspect
+import os
+import time
+from typing import Any
+
+# from mem0 import Memory
+from smolagents import CodeAgent, LiteLLMModel, tool
+
+# from palimpzest.agents.search_agents import DataDiscoveryAgent, SearchManagerAgent
+from palimpzest.core.data.context import Context
+from palimpzest.core.data.context_manager import ContextManager
+from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
+from palimpzest.query.operators.physical import PhysicalOperator
+
+
+def make_tool(bound_method):
+    # Get the original function and bound instance
+    func = bound_method.__func__
+    instance = bound_method.__self__
+
+    # Get the signature and remove 'self'
+    sig = inspect.signature(func)
+    params = list(sig.parameters.values())[1:]  # skip 'self'
+    new_sig = inspect.Signature(parameters=params, return_annotation=sig.return_annotation)
+
+    # Create a wrapper function dynamically
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(instance, *args, **kwargs)
+
+    # Update the __signature__ to reflect the new one without 'self'
+    wrapper.__signature__ = new_sig
+
+    return wrapper
+
+
+class SmolAgentsSearch(PhysicalOperator):
+    """
+    Physical operator for searching with Smol Agents.
+    """
+    def __init__(self, context_id: str, search_query: str, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.context_id = context_id
+        self.search_query = search_query
+        # self.model_id = "anthropic/claude-3-7-sonnet-latest"
+        self.model_id = "openai/gpt-4o-mini-2024-07-18"
+        # self.model_id = "openai/gpt-4o-2024-08-06"
+        api_key = os.getenv("ANTHROPIC_API_KEY") if "anthropic" in self.model_id else os.getenv("OPENAI_API_KEY")
+        self.model = LiteLLMModel(model_id=self.model_id, api_key=api_key)
+
+    def __str__(self):
+        op = super().__str__()
+        op += f"    Context ID: {self.context_id:20s}\n"
+        op += f"    Search Query: {self.search_query:20s}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        return {
+            "context_id": self.context_id,
+            "search_query": self.search_query,
+            **id_params,
+        }
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {
+            "context_id": self.context_id,
+            "search_query": self.search_query,
+            **op_params,
+        }
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        return OperatorCostEstimates(
+            cardinality=source_op_cost_estimates.cardinality,
+            time_per_record=100,
+            cost_per_record=1,
+            quality=1.0,
+        )
+
+    def _create_record_set(
+        self,
+        candidate: DataRecord,
+        generation_stats: GenerationStats,
+        total_time: float,
+        answer: dict[str, Any],
+    ) -> DataRecordSet:
+        """
+        Given an input DataRecord and a determination of whether it passed the filter or not,
+        construct the resulting RecordSet.
+        """
+        # create new DataRecord and set passed_operator attribute
+        dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
+        for field in self.output_schema.model_fields:
+            if field in answer:
+                dr[field] = answer[field]
+
+        # create RecordOpStats object
+        record_op_stats = RecordOpStats(
+            record_id=dr.id,
+            record_parent_ids=dr.parent_ids,
+            record_source_indices=dr.source_indices,
+            record_state=dr.to_dict(include_bytes=False),
+            full_op_id=self.get_full_op_id(),
+            logical_op_id=self.logical_op_id,
+            op_name=self.op_name(),
+            time_per_record=total_time,
+            cost_per_record=generation_stats.cost_per_record,
+            model_name=self.get_model_name(),
+            total_input_tokens=generation_stats.total_input_tokens,
+            total_output_tokens=generation_stats.total_output_tokens,
+            total_input_cost=generation_stats.total_input_cost,
+            total_output_cost=generation_stats.total_output_cost,
+            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+            total_llm_calls=generation_stats.total_llm_calls,
+            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+            answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
+            op_details={k: str(v) for k, v in self.get_id_params().items()},
+        )
+
+        return DataRecordSet([dr], [record_op_stats])
+
+    def __call__(self, candidate: DataRecord) -> Any:
+        start_time = time.time()
+
+        # get the input context object and its tools
+        input_context: Context = candidate.context
+        description = input_context.description
+        tools = [tool(make_tool(f)) for f in input_context.tools]
+
+        # # construct the full search query
+        # full_query = f"Please execute the following search query. Output a **detailed** description of (1) which data you look at, and (2) what you find in that data. Avoid making overly broad statements such as \"What you're searching for is not present in the dataset\". Instead, make more precise statments like \"What you're searching for is not present in files A.txt, B.txt, and C.txt, but may be present elsewhere\".\n\nQUERY: {self.search_query}"
+
+        # perform the computation
+        instructions = f"\n\nHere is a description of the Context whose data you will be working with, as well as any previously computed results:\n\n{description}"
+        agent = CodeAgent(
+            tools=tools,
+            model=self.model,
+            add_base_tools=False,
+            instructions=instructions,
+            return_full_result=True,
+            additional_authorized_imports=["pandas", "io", "os"],
+        )
+        result = agent.run(self.search_query)
+        # NOTE: you can see the system prompt with `agent.memory.system_prompt.system_prompt`
+        # full_steps = agent.memory.get_full_steps()
+
+        # compute generation stats
+        response = result.output
+        input_tokens = result.token_usage.input_tokens
+        output_tokens = result.token_usage.output_tokens
+        cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6)  # (2.5 / 1e6) #
+        cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6)  # (10.0 / 1e6) #
+        input_cost = input_tokens * cost_per_input_token
+        output_cost = output_tokens * cost_per_output_token
+        generation_stats = GenerationStats(
+            model_name=self.model_id,
+            total_input_tokens=input_tokens,
+            total_output_tokens=output_tokens,
+            total_input_cost=input_cost,
+            total_output_cost=output_cost,
+            cost_per_record=input_cost + output_cost,
+            llm_call_duration_secs=time.time() - start_time,
+        )
+
+        # update the description of the Context to include the search result
+        new_description = f"RESULT: {response}\n\n"
+        cm = ContextManager()
+        cm.update_context(id=self.context_id, description=new_description)
+
+        # create and return record set
+        field_answers = {
+            "context": cm.get_context(id=self.context_id),
+        }
+        record_set = self._create_record_set(
+            candidate,
+            generation_stats,
+            time.time() - start_time,
+            field_answers,
+        )
+
+        return record_set
+
+
+# class SmolAgentsManagedSearch(PhysicalOperator):
+#     """
+#     Physical operator for searching with Smol Agents using an Orchestrator and a Data Discovery Agent.
+#     """
+#     def __init__(self, context_id: str, search_query: str, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.context_id = context_id
+#         self.search_query = search_query
+#         # self.model_id = "anthropic/claude-3-7-sonnet-latest"
+#         self.model_id = "openai/gpt-4o-mini-2024-07-18"
+#         # self.model_id = "o1"
+#         model_params = {
+#             "model_id": self.model_id,
+#             "custom_role_conversions": {"tool-call": "assistant", "tool-response": "user"},
+#             "max_completion_tokens": 8192,
+#         }
+#         if self.model_id == "o1":
+#             model_params["reasoning_effort"] = "high"
+#         self.model = LiteLLMModel(**model_params)
+#         self.text_limit = 100000
+#         self.memory = Memory()
+
+#     def __str__(self):
+#         op = super().__str__()
+#         op += f"    Context ID: {self.context_id:20s}\n"
+#         op += f"    Search Query: {self.search_query:20s}\n"
+#         return op
+
+#     def get_id_params(self):
+#         id_params = super().get_id_params()
+#         return {
+#             "context_id": self.context_id,
+#             "search_query": self.search_query,
+#             **id_params,
+#         }
+
+#     def get_op_params(self):
+#         op_params = super().get_op_params()
+#         return {
+#             "context_id": self.context_id,
+#             "search_query": self.search_query,
+#             **op_params,
+#         }
+
+#     def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+#         return OperatorCostEstimates(
+#             cardinality=source_op_cost_estimates.cardinality,
+#             time_per_record=100,
+#             cost_per_record=1,
+#             quality=1.0,
+#         )
+
+#     def _create_record_set(
+#         self,
+#         candidate: DataRecord,
+#         generation_stats: GenerationStats,
+#         total_time: float,
+#         answer: dict[str, Any],
+#     ) -> DataRecordSet:
+#         """
+#         Given an input DataRecord and a determination of whether it passed the filter or not,
+#         construct the resulting RecordSet.
+#         """
+#         # create new DataRecord and set passed_operator attribute
+#         dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
+#         for field in self.output_schema.model_fields:
+#             if field in answer:
+#                 dr[field] = answer[field]
+
+#         # create RecordOpStats object
+#         record_op_stats = RecordOpStats(
+#             record_id=dr.id,
+#             record_parent_ids=dr.parent_ids,
+#             record_source_indices=dr.source_indices,
+#             record_state=dr.to_dict(include_bytes=False),
+#             full_op_id=self.get_full_op_id(),
+#             logical_op_id=self.logical_op_id,
+#             op_name=self.op_name(),
+#             time_per_record=total_time,
+#             cost_per_record=generation_stats.cost_per_record,
+#             model_name=self.get_model_name(),
+#             total_input_tokens=generation_stats.total_input_tokens,
+#             total_output_tokens=generation_stats.total_output_tokens,
+#             total_input_cost=generation_stats.total_input_cost,
+#             total_output_cost=generation_stats.total_output_cost,
+#             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+#             fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+#             total_llm_calls=generation_stats.total_llm_calls,
+#             total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+#             answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
+#             op_details={k: str(v) for k, v in self.get_id_params().items()},
+#         )
+
+#         return DataRecordSet([dr], [record_op_stats])
+
+#     def __call__(self, candidate: DataRecord) -> Any:
+#         start_time = time.time()
+
+#         # get the input context object and its tools
+#         input_context: Context = candidate.context
+#         description = input_context.description
+#         tools = [tool(make_tool(f)) for f in input_context.tools]
+
+#         # create a memory tool for accessing past searches
+#         @tool
+#         def tool_search_history(query: str) -> str:
+#             """
+#             This tool enables the agent to search through its history of execution in previous sessions.
+#             Thus, the agent can learn more about what it has done in the past by invoking this tool with
+#             a query describing what past interactions the agent might be curious about.
+
+#             Args:
+#                 query (str): A description of what the agent wishes to search for in its execution history.
+
+#             Returns:
+#                 str: A summary of the agent execution history which is relevant to the query.
+#             """
+#             memories = self.memory.search(query=query, user_id="data_discovery_agent")
+#             memory_str = ""
+#             for idx, memory in enumerate(memories):
+#                 memory_str += f"MEMORY {idx+1}: {memory['memory']}"
+#             return memory_str
+
+#         # tools.append(tool_search_history)
+#         data_discovery_agent = CodeAgent(
+#             model=self.model,
+#             tools=tools,
+#             max_steps=20,
+#             verbosity_level=2,
+#             planning_interval=4,
+#             name="data_discovery_agent",
+#             description="""A team member that will search a data repository to find files which help to answer your question.
+#             Ask him for all your questions that require searching a repository of relevant data.
+#             Provide him as much context as possible, in particular if you need to search on a specific timeframe!
+#             And don't hesitate to provide him with a complex search task, like finding a difference between two files.
+#             Your request must be a real sentence, not a keyword search! Like "Find me this information (...)" rather than a few keywords.
+#             """,
+#             provide_run_summary=True,
+#         )
+#         data_discovery_agent.prompt_templates["managed_agent"]["task"] += f"""\n\nHere is a description of the context you will be working with: {description}\n\nSearch as many files as possible before returning your final answer.\n\nAdditionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
+
+#         manager_agent = CodeAgent(
+#             model=self.model,
+#             tools=tools,
+#             max_steps=12,
+#             verbosity_level=2,
+#             additional_authorized_imports=["*"],
+#             planning_interval=4,
+#             managed_agents=[data_discovery_agent],
+#             return_full_result=True,
+#         )
+
+#         # TODO: improve context descriptions and add memory from there; expand to multi-modal benchmark(s)
+#         # perform the computation
+#         result = manager_agent.run(self.search_query)
+
+#         # compute generation stats
+#         response = result.output
+#         input_tokens = result.token_usage.input_tokens
+#         output_tokens = result.token_usage.output_tokens
+#         cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6)  # (15.0 / 1e6)
+#         cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6)  # (60.0 / 1e6)
+#         input_cost = input_tokens * cost_per_input_token
+#         output_cost = output_tokens * cost_per_output_token
+#         generation_stats = GenerationStats(
+#             model_name=self.model_id,
+#             total_input_tokens=input_tokens,
+#             total_output_tokens=output_tokens,
+#             total_input_cost=input_cost,
+#             total_output_cost=output_cost,
+#             cost_per_record=input_cost + output_cost,
+#             llm_call_duration_secs=time.time() - start_time,
+#         )
+
+#         # update the description of the Context to include the search result
+#         new_description = f"RESULT: {response}\n\n"
+#         cm = ContextManager()
+#         cm.update_context(id=self.context_id, description=new_description)
+
+#         # create and return record set
+#         field_answers = {
+#             "context": cm.get_context(id=self.context_id),
+#         }
+#         record_set = self._create_record_set(
+#             candidate,
+#             generation_stats,
+#             time.time() - start_time,
+#             field_answers,
+#         )
+
+#         return record_set
+
+
+# class SmolAgentsCustomManagedSearch(PhysicalOperator):
+#     """
+#     Physical operator for searching with Smol Agents using an Orchestrator and a Data Discovery Agent.
+#     """
+#     def __init__(self, context_id: str, search_query: str, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.context_id = context_id
+#         self.search_query = search_query
+#         # self.model_id = "anthropic/claude-3-7-sonnet-latest"
+#         self.model_id = "openai/gpt-4o-mini-2024-07-18"
+#         # self.model_id = "o1"
+#         model_params = {
+#             "model_id": self.model_id,
+#             "custom_role_conversions": {"tool-call": "assistant", "tool-response": "user"},
+#             "max_completion_tokens": 8192,
+#         }
+#         if self.model_id == "o1":
+#             model_params["reasoning_effort"] = "high"
+#         self.model = LiteLLMModel(**model_params)
+#         self.text_limit = 100000
+
+#     def __str__(self):
+#         op = super().__str__()
+#         op += f"    Context ID: {self.context_id:20s}\n"
+#         op += f"    Search Query: {self.search_query:20s}\n"
+#         return op
+
+#     def get_id_params(self):
+#         id_params = super().get_id_params()
+#         return {
+#             "context_id": self.context_id,
+#             "search_query": self.search_query,
+#             **id_params,
+#         }
+
+#     def get_op_params(self):
+#         op_params = super().get_op_params()
+#         return {
+#             "context_id": self.context_id,
+#             "search_query": self.search_query,
+#             **op_params,
+#         }
+
+#     def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+#         return OperatorCostEstimates(
+#             cardinality=source_op_cost_estimates.cardinality,
+#             time_per_record=100,
+#             cost_per_record=1,
+#             quality=1.0,
+#         )
+
+#     def _create_record_set(
+#         self,
+#         candidate: DataRecord,
+#         generation_stats: GenerationStats,
+#         total_time: float,
+#         answer: dict[str, Any],
+#     ) -> DataRecordSet:
+#         """
+#         Given an input DataRecord and a determination of whether it passed the filter or not,
+#         construct the resulting RecordSet.
+#         """
+#         # create new DataRecord and set passed_operator attribute
+#         dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
+#         for field in self.output_schema.model_fields:
+#             if field in answer:
+#                 dr[field] = answer[field]
+
+#         # create RecordOpStats object
+#         record_op_stats = RecordOpStats(
+#             record_id=dr.id,
+#             record_parent_ids=dr.parent_ids,
+#             record_source_indices=dr.source_indices,
+#             record_state=dr.to_dict(include_bytes=False),
+#             full_op_id=self.get_full_op_id(),
+#             logical_op_id=self.logical_op_id,
+#             op_name=self.op_name(),
+#             time_per_record=total_time,
+#             cost_per_record=generation_stats.cost_per_record,
+#             model_name=self.get_model_name(),
+#             total_input_tokens=generation_stats.total_input_tokens,
+#             total_output_tokens=generation_stats.total_output_tokens,
+#             total_input_cost=generation_stats.total_input_cost,
+#             total_output_cost=generation_stats.total_output_cost,
+#             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+#             fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+#             total_llm_calls=generation_stats.total_llm_calls,
+#             total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+#             answer={k: v.description if isinstance(v, Context) else v for k, v in answer.items()},
+#             op_details={k: str(v) for k, v in self.get_id_params().items()},
+#         )
+
+#         return DataRecordSet([dr], [record_op_stats])
+
+#     def __call__(self, candidate: DataRecord) -> Any:
+#         start_time = time.time()
+
+#         # get the input context object and its tools
+#         input_context: Context = candidate.context
+#         description = input_context.description
+#         tools = [tool(make_tool(f)) for f in input_context.tools]
+
+#         # TODO: add semantic operators to tools
+#         data_discovery_agent = DataDiscoveryAgent(self.context_id, description, model=self.model, tools=tools)
+#         search_manager_agent = SearchManagerAgent(self.context_id, description, model=self.model, tools=tools, managed_agents=[data_discovery_agent])
+
+#         # perform the computation
+#         result = search_manager_agent.run(self.search_query)
+
+#         # compute generation stats
+#         response = result.output
+#         input_tokens = result.token_usage.input_tokens
+#         output_tokens = result.token_usage.output_tokens
+#         cost_per_input_token = (3.0 / 1e6) if "anthropic" in self.model_id else (0.15 / 1e6)  # (15.0 / 1e6)
+#         cost_per_output_token = (15.0 / 1e6) if "anthropic" in self.model_id else (0.6 / 1e6)  # (60.0 / 1e6)
+#         input_cost = input_tokens * cost_per_input_token
+#         output_cost = output_tokens * cost_per_output_token
+#         generation_stats = GenerationStats(
+#             model_name=self.model_id,
+#             total_input_tokens=input_tokens,
+#             total_output_tokens=output_tokens,
+#             total_input_cost=input_cost,
+#             total_output_cost=output_cost,
+#             cost_per_record=input_cost + output_cost,
+#             llm_call_duration_secs=time.time() - start_time,
+#         )
+
+#         # update the description of the Context to include the search result
+#         new_description = f"RESULT: {response}\n\n"
+#         cm = ContextManager()
+#         cm.update_context(id=self.context_id, description=new_description)
+
+#         # create and return record set
+#         field_answers = {
+#             "context": cm.get_context(id=self.context_id),
+#         }
+#         record_set = self._create_record_set(
+#             candidate,
+#             generation_stats,
+#             time.time() - start_time,
+#             field_answers,
+#         )
+
+#         return record_set
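Aside on the `make_tool` helper at the top of the new `search.py`: `smolagents.tool` expects a plain function whose signature it can inspect, while a `Context` exposes its tools as bound methods, whose leading `self` parameter would otherwise leak into the generated tool schema. The helper hides the instance in a closure and rewrites `__signature__`. A minimal, self-contained sketch of the idea; the `FileContext` class is hypothetical, for illustration only:

import functools
import inspect

def make_tool(bound_method):
    # unwrap the bound method into its underlying function and instance
    func = bound_method.__func__
    instance = bound_method.__self__

    # rebuild the signature without the leading 'self' parameter
    sig = inspect.signature(func)
    params = list(sig.parameters.values())[1:]

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(instance, *args, **kwargs)

    # inspect.signature() will report this signature instead of the wrapped one
    wrapper.__signature__ = sig.replace(parameters=params)
    return wrapper

class FileContext:  # hypothetical stand-in for a palimpzest Context
    def list_files(self, suffix: str) -> str:
        """List the files in this context ending with the given suffix."""
        return f"files ending in {suffix}"

list_files = make_tool(FileContext().list_files)
print(inspect.signature(list_files))  # (suffix: str) -> str -- no 'self'
print(list_files(".txt"))             # files ending in .txt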
palimpzest/query/operators/split_convert.py
@@ -2,26 +2,28 @@ from __future__ import annotations
 
 import math
 
+from pydantic.fields import FieldInfo
+
 from palimpzest.constants import (
     MODEL_CARDS,
     NAIVE_EST_NUM_INPUT_TOKENS,
     NAIVE_EST_NUM_OUTPUT_TOKENS,
     PromptStrategy,
 )
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import Field, StringField
-from palimpzest.query.generators.generators import generator_factory
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates
+from palimpzest.query.generators.generators import Generator
 from palimpzest.query.operators.convert import LLMConvert
 
 
 class SplitConvert(LLMConvert):
     def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
+        kwargs["prompt_strategy"] = None
         super().__init__(*args, **kwargs)
         self.num_chunks = num_chunks
         self.min_size_to_chunk = min_size_to_chunk
-        self.split_generator = generator_factory(self.model, PromptStrategy.SPLIT_PROPOSER, self.cardinality, self.verbose)
-        self.split_merge_generator = generator_factory(self.model, PromptStrategy.SPLIT_MERGER, self.cardinality, self.verbose)
+        self.split_generator = Generator(self.model, PromptStrategy.SPLIT_PROPOSER, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
+        self.split_merge_generator = Generator(self.model, PromptStrategy.SPLIT_MERGER, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
 
         # crude adjustment factor for naive estimation in no-sentinel setting
         self.naive_quality_adjustment = 0.6
@@ -103,8 +105,8 @@ class SplitConvert(LLMConvert):
             content = candidate[field_name]
 
             # do not chunk this field if it is not a string or a list of strings
-            is_string_field = isinstance(field, StringField)
-            is_list_string_field = hasattr(field, "element_type") and isinstance(field.element_type, StringField)
+            is_string_field = field.annotation in [str, str | None]
+            is_list_string_field = field.annotation in [list[str], list[str] | None]
            if not (is_string_field or is_list_string_field):
                 field_name_to_chunked_content[field_name] = [content]
                 continue
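Context for this hunk: in 0.8.0 schema fields are plain pydantic v2 `FieldInfo` objects rather than palimpzest's removed `Field`/`StringField` classes (see `palimpzest/core/lib/fields.py +0 -141` above), so the string check becomes an equality test on the declared annotation. A small self-contained sketch of that behavior; the `Example` model here is hypothetical:

from pydantic import BaseModel

class Example(BaseModel):  # hypothetical schema, for illustration only
    title: str
    tags: list[str] | None
    count: int

for field_name, field in Example.model_fields.items():
    # mirrors the SplitConvert checks: chunk only strings and lists of strings
    is_string_field = field.annotation in [str, str | None]
    is_list_string_field = field.annotation in [list[str], list[str] | None]
    print(field_name, is_string_field or is_list_string_field)
# title True
# tags True
# count False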
@@ -136,7 +138,7 @@ class SplitConvert(LLMConvert):
 
         return candidates
 
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         # get the set of input fields to use for the convert operation
         input_fields = self.get_input_fields()
 
palimpzest/query/optimizer/__init__.py
@@ -1,15 +1,10 @@
+from palimpzest.query.optimizer.rules import AddContextsBeforeComputeRule as _AddContextsBeforeComputeRule
 from palimpzest.query.optimizer.rules import (
     AggregateRule as _AggregateRule,
 )
 from palimpzest.query.optimizer.rules import (
     BasicSubstitutionRule as _BasicSubstitutionRule,
 )
-from palimpzest.query.optimizer.rules import (
-    CodeSynthesisConvertRule as _CodeSynthesisConvertRule,
-)
-from palimpzest.query.optimizer.rules import (
-    CodeSynthesisConvertSingleRule as _CodeSynthesisConvertSingleRule,
-)
 from palimpzest.query.optimizer.rules import (
     CriticAndRefineConvertRule as _CriticAndRefineConvertRule,
 )
@@ -22,6 +17,9 @@ from palimpzest.query.optimizer.rules import (
 from palimpzest.query.optimizer.rules import (
     LLMFilterRule as _LLMFilterRule,
 )
+from palimpzest.query.optimizer.rules import (
+    LLMJoinRule as _LLMJoinRule,
+)
 from palimpzest.query.optimizer.rules import (
     MixtureOfAgentsConvertRule as _MixtureOfAgentsConvertRule,
 )
@@ -51,14 +49,14 @@ from palimpzest.query.optimizer.rules import (
 )
 
 ALL_RULES = [
+    _AddContextsBeforeComputeRule,
     _AggregateRule,
     _BasicSubstitutionRule,
-    _CodeSynthesisConvertRule,
-    _CodeSynthesisConvertSingleRule,
     _CriticAndRefineConvertRule,
     _ImplementationRule,
     _LLMConvertBondedRule,
     _LLMFilterRule,
+    _LLMJoinRule,
     _MixtureOfAgentsConvertRule,
     _NonLLMConvertRule,
     _NonLLMFilterRule,
@@ -74,7 +72,7 @@ IMPLEMENTATION_RULES = [
     rule
     for rule in ALL_RULES
     if issubclass(rule, _ImplementationRule)
-    and rule not in [_CodeSynthesisConvertRule, _ImplementationRule]
+    and rule not in [_ImplementationRule]
 ]
 
 TRANSFORMATION_RULES = [
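The surrounding (unchanged) code derives the rule groups from ALL_RULES by subclass checks, so with code synthesis removed only the abstract `_ImplementationRule` base itself still needs to be excluded. A toy sketch of the pattern, using placeholder classes rather than the real palimpzest rule hierarchy:

class Rule: ...
class ImplementationRule(Rule): ...          # abstract base for implementation rules
class LLMJoinRule(ImplementationRule): ...   # concrete implementation rule
class AggregateRule(Rule): ...               # not an implementation rule

ALL_RULES = [AggregateRule, ImplementationRule, LLMJoinRule]

# keep concrete implementation rules; exclude the abstract base itself
IMPLEMENTATION_RULES = [
    rule
    for rule in ALL_RULES
    if issubclass(rule, ImplementationRule) and rule not in [ImplementationRule]
]
assert IMPLEMENTATION_RULES == [LLMJoinRule]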