flock-core 0.4.0b3-py3-none-any.whl → 0.4.0b5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flock-core might be problematic.
- flock/__init__.py +12 -0
- flock/cli/config.py +8 -0
- flock/cli/constants.py +11 -0
- flock/cli/create_flock.py +18 -6
- flock/cli/execute_flock.py +397 -1
- flock/cli/loaded_flock_cli.py +19 -4
- flock/cli/runner.py +41 -0
- flock/config.py +5 -0
- flock/core/api/endpoints.py +102 -2
- flock/core/api/main.py +214 -0
- flock/core/api/models.py +63 -0
- flock/core/api/run_store.py +153 -1
- flock/core/api/runner.py +38 -0
- flock/core/context/context_vars.py +1 -0
- flock/core/evaluation/utils.py +312 -0
- flock/core/execution/batch_executor.py +325 -0
- flock/core/execution/evaluation_executor.py +438 -0
- flock/core/flock.py +325 -1146
- flock/core/serialization/flock_serializer.py +717 -0
- flock/core/tools/azure_tools.py +2 -1
- flock/core/tools/basic_tools.py +1 -1
- flock/core/util/loader.py +59 -0
- flock/modules/output/output_module.py +43 -8
- {flock_core-0.4.0b3.dist-info → flock_core-0.4.0b5.dist-info}/METADATA +4 -1
- {flock_core-0.4.0b3.dist-info → flock_core-0.4.0b5.dist-info}/RECORD +28 -20
- {flock_core-0.4.0b3.dist-info → flock_core-0.4.0b5.dist-info}/WHEEL +0 -0
- {flock_core-0.4.0b3.dist-info → flock_core-0.4.0b5.dist-info}/entry_points.txt +0 -0
- {flock_core-0.4.0b3.dist-info → flock_core-0.4.0b5.dist-info}/licenses/LICENSE +0 -0

flock/core/evaluation/utils.py
@@ -0,0 +1,312 @@
+# src/flock/core/util/evaluation_helpers.py
+import inspect
+import sys
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any, Union
+
+import pandas as pd
+from box import Box
+from datasets import get_dataset_config_names, load_dataset
+
+from flock.core.flock_agent import FlockAgent
+from flock.core.flock_evaluator import FlockEvaluator
+from flock.core.logging.logging import get_logger
+
+# Potentially import metrics libraries like rouge_score, nltk, sentence_transformers
+
+logger_helpers = get_logger("util.evaluation")
+
+
+def load_and_merge_all_configs(dataset_name: str) -> pd.DataFrame:
+    all_configs = get_dataset_config_names(dataset_name)
+    all_dfs = []
+
+    for config in all_configs:
+        dataset_dict = load_dataset(dataset_name, config)
+        for split_name, split_dataset in dataset_dict.items():
+            df = split_dataset.to_pandas()
+            df["config"] = config
+            df["split"] = split_name
+            all_dfs.append(df)
+
+    merged_df = pd.concat(all_dfs, ignore_index=True)
+    return merged_df
+
+
+def normalize_dataset(dataset: Any) -> pd.DataFrame:
+    """Converts various dataset formats into a pandas DataFrame."""
+    if isinstance(dataset, pd.DataFrame):
+        return dataset.copy()
+    elif isinstance(dataset, str | Path):
+        path = Path(dataset)
+        if not path.exists():
+            try:
+                return load_and_merge_all_configs(dataset)
+            except Exception as e:
+                raise FileNotFoundError(
+                    f"Dataset file not found: {path}"
+                ) from e
+        if path.suffix.lower() == ".csv":
+            return pd.read_csv(path)
+        # Add support for json, jsonl etc. if needed
+        else:
+            raise ValueError(
+                f"Unsupported file type for dataset: {path.suffix}"
+            )
+    elif isinstance(dataset, list):
+        if not dataset or not isinstance(dataset[0], dict):
+            raise ValueError("Dataset list must contain dictionaries.")
+        return pd.DataFrame(dataset)
+    elif "datasets" in sys.modules and isinstance(
+        dataset, sys.modules["datasets"].Dataset
+    ):
+        # Requires 'datasets' library to be installed
+        return dataset.to_pandas()
+    else:
+        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
+
+
+def extract_value_by_dot_notation(data: dict | Box, key: str) -> Any:
+    """Retrieves a value from a nested dictionary or Box object using dot notation."""
+    if not key:
+        return None
+    keys = key.split(".")
+    value = data
+    try:
+        for k in keys:
+            if isinstance(value, (dict, Box)):
+                value = value.get(k)
+            # Add list index handling if needed: e.g., 'results[0].field'
+            # elif isinstance(value, list) and k.isdigit():
+            #     value = value[int(k)]
+            else:
+                return None  # Cannot traverse further
+            if value is None:
+                return None  # Key not found at this level
+        return value
+    except (KeyError, IndexError, AttributeError):
+        return None
+
+
+def calculate_evaluation_metrics(
+    metrics: list[Union[str, Callable, "FlockAgent", "FlockEvaluator"]],
+    metric_configs: dict[str, dict[str, Any]],
+    predicted_answers: dict[str, Any],
+    expected_answers: dict[str, Any],
+    agent_inputs: dict[str, Any],  # For context
+    agent_output: Any,  # For context
+) -> dict[str, Any]:
+    """Calculates all specified metrics for a single evaluation item."""
+    results = {}
+    for metric in metrics:
+        metric_name = ""
+        metric_result = None
+        try:
+            if isinstance(metric, str):
+                metric_name = metric
+                # Find predicted/expected values relevant to this metric string
+                # Simple case: metric name matches an answer_mapping key
+                if (
+                    metric_name in predicted_answers
+                    and metric_name in expected_answers
+                ):
+                    predicted = predicted_answers[metric_name]
+                    expected = expected_answers[metric_name]
+                    metric_func = _get_metric_function(metric_name)
+                    config = metric_configs.get(metric_name, {})
+                    metric_result = metric_func(predicted, expected, **config)
+                else:
+                    logger_helpers.warning(
+                        f"Could not find matching predicted/expected values for metric '{metric_name}' based on answer_mapping keys."
+                    )
+                    metric_result = None  # Or some error indicator
+
+            elif isinstance(metric, Callable):
+                metric_name = getattr(metric, "__name__", "custom_function")
+                # Custom functions might need specific predicted/expected pairs, or all of them
+                # Let's pass all for flexibility, user function needs to handle it
+                config = metric_configs.get(metric_name, {})
+                # Allow passing context if function signature supports it
+                sig = inspect.signature(metric)
+                call_kwargs = config.copy()
+                if "agent_inputs" in sig.parameters:
+                    call_kwargs["agent_inputs"] = agent_inputs
+                if "agent_output" in sig.parameters:
+                    call_kwargs["agent_output"] = agent_output
+
+                metric_result = metric(
+                    predicted_answers, expected_answers, **call_kwargs
+                )
+
+            # --- Placeholder for Agent/Evaluator based metrics ---
+            elif "FlockAgent" in str(
+                type(metric)
+            ):  # Avoid hard import if possible
+                metric_name = getattr(metric, "name", "judge_agent")
+                config = metric_configs.get(metric_name, {})
+                # Requires running the judge agent - needs async context
+                # metric_result = asyncio.run(_run_judge_agent(metric, predicted_answers, expected_answers, config))
+                logger_helpers.warning(
+                    f"Agent-based metric '{metric_name}' execution not implemented in this sketch."
+                )
+                metric_result = "[Agent Judge Not Implemented]"
+
+            elif "FlockEvaluator" in str(
+                type(metric)
+            ):  # Avoid hard import if possible
+                metric_name = getattr(metric, "name", "judge_evaluator")
+                config = metric_configs.get(metric_name, {})
+                # Requires running the evaluator - needs async context
+                # metric_result = asyncio.run(_run_judge_evaluator(metric, predicted_answers, expected_answers, config))
+                logger_helpers.warning(
+                    f"Evaluator-based metric '{metric_name}' execution not implemented in this sketch."
+                )
+                metric_result = "[Evaluator Judge Not Implemented]"
+            # --- End Placeholder ---
+
+            else:
+                logger_helpers.warning(
+                    f"Unsupported metric type: {type(metric)}"
+                )
+                continue
+
+            # Store result - handle dict results from metrics
+            if isinstance(metric_result, dict):
+                for sub_key, sub_value in metric_result.items():
+                    results[f"{metric_name}_{sub_key}"] = sub_value
+            else:
+                results[metric_name] = metric_result
+
+        except Exception as e:
+            logger_helpers.error(
+                f"Error calculating metric '{metric_name}': {e}"
+            )
+            results[metric_name] = f"[Error: {e}]"
+
+    return results
+
+
+def _get_metric_function(metric_name: str) -> Callable:
+    """Maps metric names to their implementation functions."""
+    # Lazy load metric libraries
+    if metric_name == "exact_match":
+        return lambda pred, act, **kw: str(pred).strip() == str(act).strip()
+    elif metric_name == "fuzzy_match":
+        try:
+            from thefuzz import fuzz
+
+            return (
+                lambda pred, act, threshold=85, **kw: fuzz.ratio(
+                    str(pred), str(act)
+                )
+                >= threshold
+            )
+        except ImportError:
+            logger_helpers.warning(
+                "fuzzy_match requires 'thefuzz': pip install thefuzz[speedup]"
+            )
+            return lambda p, a, **kw: None
+    elif metric_name.startswith("rouge"):  # rouge_1, rouge_2, rouge_l
+        try:
+            from rouge_score import rouge_scorer
+
+            scorer = rouge_scorer.RougeScorer(
+                [metric_name.replace("_", "")], use_stemmer=True
+            )
+
+            def calculate_rouge(pred, act, score_type="fmeasure", **kw):
+                scores = scorer.score(str(act), str(pred))
+                return (
+                    scores[metric_name.replace("_", "")]
+                    ._asdict()
+                    .get(score_type, 0.0)
+                )
+
+            return calculate_rouge
+        except ImportError:
+            logger_helpers.warning(
+                "rouge requires 'rouge-score': pip install rouge-score"
+            )
+            return lambda p, a, **kw: None
+    elif metric_name == "semantic_similarity":
+        try:
+            from sentence_transformers import SentenceTransformer, util
+
+            # Cache the model? Maybe pass it in via config?
+            model = SentenceTransformer("all-MiniLM-L6-v2")
+
+            def calculate_similarity(pred, act, **kw):
+                emb1 = model.encode(str(pred), convert_to_tensor=True)
+                emb2 = model.encode(str(act), convert_to_tensor=True)
+                return util.pytorch_cos_sim(emb1, emb2).item()
+
+            return calculate_similarity
+        except ImportError:
+            logger_helpers.warning(
+                "semantic_similarity requires 'sentence-transformers': pip install sentence-transformers"
+            )
+            return lambda p, a, **kw: None
+    # Add bleu, f1 etc.
+    elif metric_name == "llm_judge":
+        # This is handled by checking type in calculate_evaluation_metrics
+        # but we need a placeholder callable here if we map by string first
+        return lambda p, a, **kw: "[LLM Judge Not Implemented Directly]"
+    else:
+        raise ValueError(f"Unknown built-in metric: {metric_name}")
+
+
+def aggregate_results(results_list: list[dict[str, Any]]) -> dict[str, Any]:
+    """Aggregates evaluation results across all items."""
+    summary = {"total_items": len(results_list), "errors": 0}
+    metric_values: dict[str, list[float | bool]] = {}
+
+    for item in results_list:
+        if item.get("error"):
+            summary["errors"] += 1
+        metrics = item.get("metrics", {})
+        for name, value in metrics.items():
+            if isinstance(
+                value, (float, int, bool)
+            ):  # Only aggregate numerics/bools
+                if name not in metric_values:
+                    metric_values[name] = []
+                metric_values[name].append(value)
+
+    summary["metrics_summary"] = {}
+    for name, values in metric_values.items():
+        if not values:
+            continue
+        # Calculate different stats based on value type
+        if all(isinstance(v, bool) for v in values):
+            summary["metrics_summary"][name] = {
+                "accuracy": sum(values) / len(values)
+            }
+        elif all(isinstance(v, (int, float)) for v in values):
+            numeric_values = [v for v in values if isinstance(v, (int, float))]
+            if numeric_values:
+                summary["metrics_summary"][name] = {
+                    "mean": sum(numeric_values) / len(numeric_values),
+                    "count": len(numeric_values),
+                    # Add min, max, stddev if needed
+                }
+
+    return summary
+
+
+# --- Placeholder for async judge execution ---
+# Need to run these within the main async context or manage loops carefully
+async def _run_judge_agent(judge_agent, predicted, expected, config):
+    # Prepare input for the judge agent based on its signature
+    # E.g., judge_input = {"prediction": predicted_value, "reference": expected_value, "criteria": ...}
+    # judge_result = await judge_agent.run_async(judge_input)
+    # return judge_result  # Or extract specific score/judgement
+    return "[Agent Judge Not Implemented]"
+
+
+async def _run_judge_evaluator(judge_evaluator, predicted, expected, config):
+    # Prepare input for the judge evaluator based on its signature
+    # judge_input = {"prediction": predicted_value, "reference": expected_value, **config}
+    # judge_result = await judge_evaluator.evaluate(None, judge_input, [])  # Agent might not be needed
+    # return judge_result  # Or extract specific score/judgement
+    return "[Evaluator Judge Not Implemented]"
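
For orientation, a minimal usage sketch of the new evaluation helpers follows. It is illustrative only: the import path is inferred from the flock/core/evaluation/utils.py entry in the file list above, and the sample data and metric choice are made up, not taken from the package.

# Hypothetical usage of the new evaluation utilities (not part of the diff).
from flock.core.evaluation.utils import (
    aggregate_results,
    calculate_evaluation_metrics,
    extract_value_by_dot_notation,
    normalize_dataset,
)

# Normalize any supported dataset shape (DataFrame, CSV path, list of dicts, HF dataset).
df = normalize_dataset([{"question": "2+2?", "answer": "4"}])

# Pull a nested field out of an agent result using dot notation.
answer = extract_value_by_dot_notation({"result": {"answer": "4"}}, "result.answer")

# Score one item with a built-in string metric; the metric name must appear as a key
# in both predicted_answers and expected_answers, otherwise a warning is logged and
# the metric is skipped.
item_metrics = calculate_evaluation_metrics(
    metrics=["exact_match"],
    metric_configs={},
    predicted_answers={"exact_match": answer},
    expected_answers={"exact_match": "4"},
    agent_inputs={"question": "2+2?"},
    agent_output={"answer": "4"},
)

# Aggregate per-item metric dicts into a summary (accuracy for booleans, mean for numerics).
summary = aggregate_results([{"metrics": item_metrics}])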

flock/core/execution/batch_executor.py
@@ -0,0 +1,325 @@
+import asyncio
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from box import Box
+from opentelemetry import trace
+from pandas import DataFrame
+from rich.progress import (  # Import Rich Progress
+    BarColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+
+from flock.config import TELEMETRY
+from flock.core.context.context import FlockContext
+from flock.core.context.context_vars import FLOCK_BATCH_SILENT_MODE
+from flock.core.flock_agent import FlockAgent
+from flock.core.logging.logging import get_logger
+
+try:
+    import pandas as pd
+
+    PANDAS_AVAILABLE = True
+except ImportError:
+    pd = None
+    PANDAS_AVAILABLE = False
+
+if TYPE_CHECKING:
+    from flock.core.flock import Flock
+
+logger = get_logger("flock")
+TELEMETRY.setup_tracing()  # Setup OpenTelemetry
+tracer = trace.get_tracer(__name__)
+
+
+class BatchProcessor:
+    def __init__(self, flock_instance: "Flock"):
+        self.flock = flock_instance
+
+    async def run_batch_async(
+        self,
+        start_agent: FlockAgent | str,
+        batch_inputs: list[dict[str, Any]] | DataFrame | str,
+        input_mapping: dict[str, str] | None = None,
+        static_inputs: dict[str, Any] | None = None,
+        parallel: bool = True,
+        max_workers: int = 5,
+        use_temporal: bool | None = None,
+        box_results: bool = True,
+        return_errors: bool = False,
+        silent_mode: bool = False,
+        write_to_csv: str | None = None,
+    ) -> list[Box | dict | None | Exception]:
+        """Runs the specified agent/workflow for each item in a batch asynchronously.
+
+        Args:
+            start_agent: Agent instance or name to start each run.
+            batch_inputs: Input data in one of these forms:
+                - List of dictionaries, each representing inputs for one run
+                - Pandas DataFrame where each row is inputs for one run
+                - String path to a CSV file to load as DataFrame
+            input_mapping: Maps DataFrame/CSV column names to agent input keys (required for DataFrame/CSV).
+            static_inputs: Dictionary of inputs constant across all batch runs.
+            parallel: Whether to run local jobs in parallel (ignored if use_temporal=True).
+            max_workers: Max concurrent local workers (used if parallel=True and use_temporal=False).
+            use_temporal: Override Flock's 'enable_temporal' setting for this batch.
+            box_results: Wrap successful dictionary results in Box objects.
+            return_errors: If True, return Exception objects for failed runs instead of raising.
+            silent_mode: If True, suppress output and show progress bar instead.
+            write_to_csv: Path to save results as CSV file.
+
+        Returns:
+            List containing results (Box/dict), None (if error and not return_errors),
+            or Exception objects (if error and return_errors). Order matches input.
+
+        Raises:
+            ValueError: For invalid input combinations.
+            ImportError: If DataFrame/CSV used without pandas.
+            Exception: First exception from a run if return_errors is False.
+        """
+        effective_use_temporal = (
+            use_temporal if use_temporal is not None else self.enable_temporal
+        )
+        exec_mode = (
+            "Temporal"
+            if effective_use_temporal
+            else ("Parallel Local" if parallel else "Sequential Local")
+        )
+        logger.info(
+            f"Starting batch run for agent '{start_agent}'. Execution: {exec_mode}, Silent: {silent_mode}"
+        )
+
+        # --- Input Preparation ---
+        prepared_batch_inputs: list[dict[str, Any]] = []
+
+        if input_mapping == {}:
+            input_mapping = None
+        if static_inputs == {}:
+            static_inputs = None
+
+        if isinstance(batch_inputs, str):
+            # Handle CSV file input
+            try:
+                df = pd.read_csv(batch_inputs)
+                logger.debug(
+                    f"Loaded CSV file with {len(df)} rows: {batch_inputs}"
+                )
+                batch_inputs = df  # Convert to DataFrame for unified handling
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to load CSV file '{batch_inputs}': {e}"
+                )
+
+        if isinstance(batch_inputs, DataFrame):
+            # Handle DataFrame input
+            logger.debug(
+                f"Converting DataFrame ({len(batch_inputs)} rows) to batch inputs."
+            )
+            for _, row in batch_inputs.iterrows():
+                if input_mapping:
+                    item_input = {
+                        agent_key: row[df_col]
+                        for df_col, agent_key in input_mapping.items()
+                        if df_col in row
+                    }
+                else:
+                    item_input = row.to_dict()
+                prepared_batch_inputs.append(item_input)
+        else:
+            # Handle list of dictionaries
+            if not isinstance(batch_inputs, list):
+                raise ValueError(
+                    "batch_inputs must be a list of dictionaries, DataFrame, or CSV file path"
+                )
+
+            if input_mapping:
+                # Apply mapping to dictionary inputs
+                logger.debug("Applying input mapping to dictionary inputs")
+                for item in batch_inputs:
+                    mapped_input = {}
+                    for df_col, agent_key in input_mapping.items():
+                        if df_col in item:
+                            mapped_input[agent_key] = item[df_col]
+                        else:
+                            logger.warning(
+                                f"Input mapping key '{df_col}' not found in input dictionary"
+                            )
+                    prepared_batch_inputs.append(mapped_input)
+            else:
+                # Use dictionaries as-is if no mapping provided
+                prepared_batch_inputs = batch_inputs
+
+            logger.debug(
+                f"Using provided list of {len(prepared_batch_inputs)} batch inputs."
+            )
+
+        if not prepared_batch_inputs:
+            return []
+
+        # --- Setup Progress Bar if Silent ---
+        progress_context = None
+        progress_task_id = None
+        if silent_mode:
+            progress = Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+                TextColumn("({task.completed}/{task.total})"),
+                TimeElapsedColumn(),
+                # transient=True # Optionally remove progress bar when done
+            )
+            progress_context = progress  # Use as context manager
+            progress_task_id = progress.add_task(
+                f"Processing Batch ({exec_mode})",
+                total=len(prepared_batch_inputs),
+            )
+            progress.start()
+
+        results = [None] * len(
+            prepared_batch_inputs
+        )  # Pre-allocate results list
+        tasks = []
+        semaphore = asyncio.Semaphore(
+            max_workers if parallel and not effective_use_temporal else 1
+        )  # Semaphore for parallel local
+
+        async def worker(index, item_inputs):
+            async with semaphore:
+                full_input = {**(static_inputs or {}), **item_inputs}
+                context = FlockContext()
+                context.set_variable(FLOCK_BATCH_SILENT_MODE, silent_mode)
+
+                run_desc = f"Batch item {index + 1}"
+                logger.debug(f"{run_desc} started.")
+                try:
+                    result = await self.run_async(
+                        start_agent,
+                        full_input,
+                        box_result=box_results,
+                        context=context,
+                    )
+                    results[index] = result
+                    logger.debug(f"{run_desc} finished successfully.")
+                except Exception as e:
+                    logger.error(
+                        f"{run_desc} failed: {e}", exc_info=not return_errors
+                    )
+                    if return_errors:
+                        results[index] = e
+                    else:
+                        # If not returning errors, ensure the exception propagates
+                        # to stop asyncio.gather if running in parallel.
+                        if parallel and not effective_use_temporal:
+                            raise  # Re-raise to stop gather
+                        else:
+                            # For sequential, we just store None or the exception if return_errors=True
+                            # For Temporal, error handling happens within the workflow/activity usually
+                            results[index] = e if return_errors else None
+                finally:
+                    if progress_context:
+                        progress.update(
+                            progress_task_id, advance=1
+                        )  # Update progress
+
+        try:
+            if effective_use_temporal:
+                # Temporal Batching (Simplified: sequential execution for this example)
+                # A real implementation might use start_workflow or signals
+                logger.info(
+                    "Running batch using Temporal (executing sequentially for now)..."
+                )
+                for i, item_data in enumerate(prepared_batch_inputs):
+                    await worker(i, item_data)  # Run sequentially for demo
+                # TODO: Implement true parallel Temporal workflow execution if needed
+
+            elif parallel:
+                logger.info(
+                    f"Running batch in parallel with max_workers={max_workers}..."
+                )
+                for i, item_data in enumerate(prepared_batch_inputs):
+                    tasks.append(asyncio.create_task(worker(i, item_data)))
+                await asyncio.gather(
+                    *tasks
+                )  # gather handles exceptions based on return_errors logic in worker
+
+            else:  # Sequential Local
+                logger.info("Running batch sequentially...")
+                for i, item_data in enumerate(prepared_batch_inputs):
+                    await worker(
+                        i, item_data
+                    )  # Already handles errors internally based on return_errors
+
+            logger.info("Batch execution finished.")
+
+        except Exception as batch_error:
+            # This catch handles errors re-raised from workers when return_errors=False
+            logger.error(f"Batch execution stopped due to error: {batch_error}")
+            # No need to cancel tasks here as gather would have stopped
+            if not return_errors:
+                raise  # Re-raise the first error encountered if not returning errors
+        finally:
+            if progress_context:
+                progress.stop()
+
+        if write_to_csv:
+            try:
+                df = pd.DataFrame(results)
+                # create write_to_csv directory if it doesn't exist
+                Path(write_to_csv).parent.mkdir(parents=True, exist_ok=True)
+                df.to_csv(write_to_csv, index=False)
+                logger.info(f"Results written to CSV file: {write_to_csv}")
+            except Exception as e:
+                logger.error(f"Failed to write results to CSV: {e}")
+
+        return results
+
+    def run_batch(  # Synchronous wrapper
+        self,
+        start_agent: FlockAgent | str,
+        batch_inputs: list[dict[str, Any]] | DataFrame | str,
+        input_mapping: dict[str, str] | None = None,
+        static_inputs: dict[str, Any] | None = None,
+        parallel: bool = True,
+        max_workers: int = 5,
+        use_temporal: bool | None = None,
+        box_results: bool = True,
+        return_errors: bool = False,
+        silent_mode: bool = False,
+        write_to_csv: str | None = None,
+    ) -> list[Box | dict | None | Exception]:
+        """Synchronous wrapper for run_batch_async."""
+        # (Standard asyncio run wrapper - same as in previous suggestion)
+        try:
+            loop = asyncio.get_running_loop()
+            if loop.is_closed():
+                raise RuntimeError("Event loop is closed")
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        coro = self.run_batch_async(
+            start_agent=start_agent,
+            batch_inputs=batch_inputs,
+            input_mapping=input_mapping,
+            static_inputs=static_inputs,
+            parallel=parallel,
+            max_workers=max_workers,
+            use_temporal=use_temporal,
+            box_results=box_results,
+            return_errors=return_errors,
+            silent_mode=silent_mode,
+            write_to_csv=write_to_csv,
+        )
+
+        if asyncio.get_event_loop() is loop and not loop.is_running():
+            results = loop.run_until_complete(coro)
+            # loop.close()  # Avoid closing potentially shared loop
+            return results
+        else:
+            # Run within an existing loop
+            future = asyncio.ensure_future(coro)
+            return loop.run_until_complete(future)
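
A rough sketch of how the new batch executor might be driven follows. It is hypothetical: the module path comes from the flock/core/execution/batch_executor.py entry in the file list, while the Flock instance, agent name, column names, and file paths are placeholders. Since run_batch_async refers to self.run_async and self.enable_temporal, the released package most likely exposes this functionality through the Flock object itself rather than through BatchProcessor called directly.

# Hypothetical usage of the batch executor (not part of the diff).
from flock.core.execution.batch_executor import BatchProcessor

processor = BatchProcessor(my_flock)  # my_flock: an existing Flock instance (placeholder)

results = processor.run_batch(
    start_agent="qa_agent",                      # placeholder agent name
    batch_inputs="questions.csv",                # CSV path; each row becomes one run
    input_mapping={"question_col": "question"},  # CSV column -> agent input key
    static_inputs={"language": "en"},            # shared across all runs
    parallel=True,
    max_workers=5,
    return_errors=True,                          # collect Exceptions in place of results
    silent_mode=True,                            # Rich progress bar instead of per-run output
    write_to_csv="results/batch_out.csv",
)

With return_errors=True each failed item yields its Exception at the corresponding index, preserving input order; with silent_mode=True per-run output is suppressed in favor of the progress bar.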