sdg-hub 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdg_hub/flow_runner.py CHANGED
@@ -1,18 +1,31 @@
1
1
  """Script for running data generation flows with configurable parameters."""
2
2
 
3
3
  # Standard
4
+ from importlib import resources
5
+ from typing import Optional
4
6
  import os
7
+ import sys
8
+ import traceback
5
9
 
6
10
  # Third Party
7
11
  from datasets import load_dataset
8
12
  from openai import OpenAI
9
13
  import click
14
+ import yaml
10
15
 
11
16
  # First Party
12
17
  from sdg_hub.flow import Flow
13
18
  from sdg_hub.logger_config import setup_logger
14
19
  from sdg_hub.sdg import SDG
15
-
20
+ from sdg_hub.utils.error_handling import (
21
+ APIConnectionError,
22
+ DataGenerationError,
23
+ DataSaveError,
24
+ DatasetLoadError,
25
+ FlowConfigurationError,
26
+ FlowRunnerError,
27
+ )
28
+ from sdg_hub.utils.path_resolution import resolve_path
16
29
 
17
30
  logger = setup_logger(__name__)
18
31
 
@@ -28,7 +41,8 @@ def run_flow(
28
41
  save_freq: int = 2,
29
42
  debug: bool = False,
30
43
  dataset_start_index: int = 0,
31
- dataset_end_index: int = None,
44
+ dataset_end_index: Optional[int] = None,
45
+ api_key: Optional[str] = None,
32
46
  ) -> None:
33
47
  """Process the dataset using the specified configuration.
34
48
 
@@ -52,6 +66,12 @@ def run_flow(
52
66
  Frequency (in batches) at which to save checkpoints, by default 2.
53
67
  debug : bool, optional
54
68
  If True, enables debug mode with a smaller dataset subset, by default False.
69
+ dataset_start_index : int, optional
70
+ Start index for dataset slicing, by default 0.
71
+ dataset_end_index : Optional[int], optional
72
+ End index for dataset slicing, by default None.
73
+ api_key : Optional[str], optional
74
+ API key for the remote endpoint. If not provided, will use OPENAI_API_KEY environment variable, by default None.
55
75
 
56
76
  Returns
57
77
  -------
@@ -59,42 +79,214 @@ def run_flow(
59
79
 
60
80
  Raises
61
81
  ------
62
- FileNotFoundError
63
- If the flow configuration file is not found.
82
+ DatasetLoadError
83
+ If the dataset cannot be loaded or processed.
84
+ FlowConfigurationError
85
+ If the flow configuration is invalid or cannot be loaded.
86
+ APIConnectionError
87
+ If connection to the API endpoint fails.
88
+ DataGenerationError
89
+ If data generation fails during processing.
90
+ DataSaveError
91
+ If saving the generated data fails.
64
92
  """
65
93
  logger.info(f"Generation configuration: {locals()}\n\n")
66
- ds = load_dataset("json", data_files=ds_path, split="train")
67
- if dataset_start_index is not None and dataset_end_index is not None:
68
- ds = ds.select(range(dataset_start_index, dataset_end_index))
69
- logger.info(f"Dataset sliced from {dataset_start_index} to {dataset_end_index}")
70
- if debug:
71
- ds = ds.shuffle(seed=42).select(range(30))
72
- logger.info("Debug mode enabled. Using a subset of the dataset.")
73
-
74
- openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
75
- openai_api_base = endpoint
76
-
77
- client = OpenAI(
78
- api_key=openai_api_key,
79
- base_url=openai_api_base,
80
- )
81
-
82
- if not os.path.exists(flow_path):
83
- raise FileNotFoundError(f"Flow file not found: {flow_path}")
84
-
85
- flow = Flow(client).get_flow_from_file(flow_path)
86
- sdg = SDG(
87
- flows=[flow],
88
- num_workers=num_workers,
89
- batch_size=batch_size,
90
- save_freq=save_freq,
91
- )
92
-
93
- generated_data = sdg.generate(ds, checkpoint_dir=checkpoint_dir)
94
- if dataset_end_index is not None and dataset_start_index is not None:
95
- save_path = save_path.replace(".jsonl", f"_{dataset_start_index}_{dataset_end_index}.jsonl")
96
- generated_data.to_json(save_path, orient="records", lines=True)
97
- logger.info(f"Data saved to {save_path}")
94
+
95
+ try:
96
+ # Load and validate dataset
97
+ try:
98
+ ds = load_dataset("json", data_files=ds_path, split="train")
99
+ logger.info(
100
+ f"Successfully loaded dataset from {ds_path} with {len(ds)} rows"
101
+ )
102
+ except Exception as e:
103
+ raise DatasetLoadError(
104
+ f"Failed to load dataset from '{ds_path}'. "
105
+ f"Please check if the file exists and is a valid JSON file.",
106
+ details=str(e),
107
+ ) from e
108
+
109
+ # Apply dataset slicing if specified
110
+ try:
111
+ if dataset_start_index is not None and dataset_end_index is not None:
112
+ if dataset_start_index >= len(ds) or dataset_end_index > len(ds):
113
+ raise DatasetLoadError(
114
+ f"Dataset slice indices ({dataset_start_index}, {dataset_end_index}) "
115
+ f"are out of bounds for dataset with {len(ds)} rows"
116
+ )
117
+ if dataset_start_index >= dataset_end_index:
118
+ raise DatasetLoadError(
119
+ f"Start index ({dataset_start_index}) must be less than end index ({dataset_end_index})"
120
+ )
121
+ ds = ds.select(range(dataset_start_index, dataset_end_index))
122
+ logger.info(
123
+ f"Dataset sliced from {dataset_start_index} to {dataset_end_index}"
124
+ )
125
+
126
+ if debug:
127
+ if len(ds) < 30:
128
+ logger.warning(
129
+ f"Debug mode requested 30 samples but dataset only has {len(ds)} rows"
130
+ )
131
+ ds = ds.shuffle(seed=42).select(range(min(30, len(ds))))
132
+ logger.info(
133
+ f"Debug mode enabled. Using {len(ds)} samples from the dataset."
134
+ )
135
+ except DatasetLoadError:
136
+ raise
137
+ except Exception as e:
138
+ raise DatasetLoadError(
139
+ "Failed to process dataset slicing or debug mode.", details=str(e)
140
+ ) from e
141
+
142
+ # Validate API configuration
143
+ openai_api_key = api_key or os.environ.get("OPENAI_API_KEY")
144
+ if not openai_api_key or openai_api_key == "EMPTY":
145
+ logger.warning("API key not provided and OPENAI_API_KEY not set or is 'EMPTY'. API calls may fail.")
146
+
147
+ openai_api_base = endpoint
148
+ if not openai_api_base:
149
+ raise APIConnectionError("API endpoint cannot be empty")
150
+
151
+ # Initialize OpenAI client
152
+ try:
153
+ client = OpenAI(
154
+ api_key=openai_api_key or "EMPTY",
155
+ base_url=openai_api_base,
156
+ )
157
+ # test connection with a model list
158
+ models = client.models.list()
159
+ logger.info(f"Initialized OpenAI client with endpoint: {openai_api_base}")
160
+ logger.info(f"Available models: {[model.id for model in models.data]}")
161
+ except Exception as e:
162
+ raise APIConnectionError(
163
+ f"Failed to initialize OpenAI client with endpoint '{openai_api_base}'. "
164
+ f"Please check if the endpoint is valid and accessible.",
165
+ details=str(e),
166
+ ) from e
167
+
168
+ # Load and validate flow configuration
169
+ try:
170
+ base_path = str(resources.files(__package__))
171
+ flow_path = resolve_path(flow_path, [".", base_path])
172
+ if not os.path.exists(flow_path):
173
+ raise FlowConfigurationError(
174
+ f"Flow configuration file not found: {flow_path}"
175
+ )
176
+
177
+ # Validate flow file is readable YAML
178
+ try:
179
+ with open(flow_path, "r", encoding="utf-8") as f:
180
+ flow_config = yaml.safe_load(f)
181
+ if not flow_config:
182
+ raise FlowConfigurationError(
183
+ f"Flow configuration file is empty: {flow_path}"
184
+ )
185
+ logger.info(f"Successfully loaded flow configuration from {flow_path}")
186
+ except yaml.YAMLError as e:
187
+ raise FlowConfigurationError(
188
+ f"Flow configuration file '{flow_path}' contains invalid YAML.",
189
+ details=str(e),
190
+ ) from e
191
+ except Exception as e:
192
+ raise FlowConfigurationError(
193
+ f"Failed to read flow configuration file '{flow_path}'.",
194
+ details=str(e),
195
+ ) from e
196
+
197
+ flow = Flow(client).get_flow_from_file(flow_path)
198
+ logger.info("Successfully initialized flow from configuration")
199
+ except FlowConfigurationError:
200
+ raise
201
+ except Exception as e:
202
+ raise FlowConfigurationError(
203
+ f"Failed to create flow from configuration file '{flow_path}'. "
204
+ f"Please check the flow configuration format and block definitions.",
205
+ details=str(e),
206
+ ) from e
207
+
208
+ # Initialize SDG and generate data
209
+ try:
210
+ sdg = SDG(
211
+ flows=[flow],
212
+ num_workers=num_workers,
213
+ batch_size=batch_size,
214
+ save_freq=save_freq,
215
+ )
216
+ logger.info(
217
+ f"Initialized SDG with {num_workers} workers, batch size {batch_size}"
218
+ )
219
+
220
+ # Ensure checkpoint directory exists if specified
221
+ if checkpoint_dir and not os.path.exists(checkpoint_dir):
222
+ os.makedirs(checkpoint_dir, exist_ok=True)
223
+ logger.info(f"Created checkpoint directory: {checkpoint_dir}")
224
+
225
+ generated_data = sdg.generate(ds, checkpoint_dir=checkpoint_dir)
226
+
227
+ if generated_data is None or len(generated_data) == 0:
228
+ raise DataGenerationError(
229
+ "Data generation completed but no data was generated. "
230
+ "This may indicate issues with the flow configuration or input data."
231
+ )
232
+
233
+ logger.info(f"Successfully generated {len(generated_data)} rows of data")
234
+
235
+ except Exception as e:
236
+ if isinstance(e, DataGenerationError):
237
+ raise
238
+ raise DataGenerationError(
239
+ "Data generation failed during processing. This could be due to:"
240
+ "\n- API connection issues with the endpoint"
241
+ "\n- Invalid flow configuration or block parameters"
242
+ "\n- Insufficient system resources (try reducing batch_size or num_workers)"
243
+ "\n- Input data format incompatibility",
244
+ details=f"Endpoint: {openai_api_base}, Error: {e}",
245
+ ) from e
246
+
247
+ # Save generated data
248
+ try:
249
+ # Adjust save path for dataset slicing
250
+ final_save_path = save_path
251
+ if dataset_end_index is not None and dataset_start_index is not None:
252
+ final_save_path = save_path.replace(
253
+ ".jsonl", f"_{dataset_start_index}_{dataset_end_index}.jsonl"
254
+ )
255
+
256
+ # Ensure save directory exists
257
+ save_dir = os.path.dirname(final_save_path)
258
+ if save_dir and not os.path.exists(save_dir):
259
+ os.makedirs(save_dir, exist_ok=True)
260
+ logger.info(f"Created save directory: {save_dir}")
261
+
262
+ generated_data.to_json(final_save_path, orient="records", lines=True)
263
+ logger.info(f"Data successfully saved to {final_save_path}")
264
+
265
+ except Exception as e:
266
+ raise DataSaveError(
267
+ f"Failed to save generated data to '{final_save_path}'. "
268
+ f"Please check write permissions and disk space.",
269
+ details=str(e),
270
+ ) from e
271
+
272
+ except (
273
+ DatasetLoadError,
274
+ FlowConfigurationError,
275
+ APIConnectionError,
276
+ DataGenerationError,
277
+ DataSaveError,
278
+ ):
279
+ # Re-raise our custom exceptions with their detailed messages
280
+ raise
281
+ except Exception as e:
282
+ # Catch any unexpected errors
283
+ logger.error(f"Unexpected error during flow execution: {e}")
284
+ logger.error(f"Traceback: {traceback.format_exc()}")
285
+ raise FlowRunnerError(
286
+ "An unexpected error occurred during flow execution. "
287
+ "Please check the logs for more details.",
288
+ details=str(e),
289
+ ) from e
98
290
 
99
291
 
100
292
  @click.command()
@@ -154,8 +346,18 @@ def run_flow(
154
346
  is_flag=True,
155
347
  help="Enable debug mode with a smaller dataset subset.",
156
348
  )
157
- @click.option("--dataset_start_index", type=int, default=0, help="Start index of the dataset.")
158
- @click.option("--dataset_end_index", type=int, default=None, help="End index of the dataset.")
349
+ @click.option(
350
+ "--dataset_start_index", type=int, default=0, help="Start index of the dataset."
351
+ )
352
+ @click.option(
353
+ "--dataset_end_index", type=int, default=None, help="End index of the dataset."
354
+ )
355
+ @click.option(
356
+ "--api_key",
357
+ type=str,
358
+ default=None,
359
+ help="API key for the remote endpoint. If not provided, will use OPENAI_API_KEY environment variable.",
360
+ )
159
361
  def main(
160
362
  ds_path: str,
161
363
  bs: int,
@@ -167,7 +369,8 @@ def main(
167
369
  save_freq: int,
168
370
  debug: bool,
169
371
  dataset_start_index: int,
170
- dataset_end_index: int,
372
+ dataset_end_index: Optional[int],
373
+ api_key: Optional[str],
171
374
  ) -> None:
172
375
  """CLI entry point for running data generation flows.
173
376
 
@@ -191,24 +394,55 @@ def main(
191
394
  Frequency (in batches) at which to save checkpoints.
192
395
  debug : bool
193
396
  If True, enables debug mode with a smaller dataset subset.
397
+ dataset_start_index : int
398
+ Start index for dataset slicing.
399
+ dataset_end_index : Optional[int]
400
+ End index for dataset slicing.
401
+ api_key : Optional[str]
402
+ API key for the remote endpoint. If not provided, will use OPENAI_API_KEY environment variable.
194
403
 
195
404
  Returns
196
405
  -------
197
406
  None
198
407
  """
199
- run_flow(
200
- ds_path=ds_path,
201
- batch_size=bs,
202
- num_workers=num_workers,
203
- save_path=save_path,
204
- endpoint=endpoint,
205
- flow_path=flow,
206
- checkpoint_dir=checkpoint_dir,
207
- save_freq=save_freq,
208
- debug=debug,
209
- dataset_start_index=dataset_start_index,
210
- dataset_end_index=dataset_end_index,
211
- )
408
+ try:
409
+ run_flow(
410
+ ds_path=ds_path,
411
+ batch_size=bs,
412
+ num_workers=num_workers,
413
+ save_path=save_path,
414
+ endpoint=endpoint,
415
+ flow_path=flow,
416
+ checkpoint_dir=checkpoint_dir,
417
+ save_freq=save_freq,
418
+ debug=debug,
419
+ dataset_start_index=dataset_start_index,
420
+ dataset_end_index=dataset_end_index,
421
+ api_key=api_key,
422
+ )
423
+ except (
424
+ DatasetLoadError,
425
+ FlowConfigurationError,
426
+ APIConnectionError,
427
+ DataGenerationError,
428
+ DataSaveError,
429
+ FlowRunnerError,
430
+ ) as e:
431
+ logger.error(f"Flow execution failed: {e}")
432
+ click.echo(f"Error: {e}", err=True)
433
+ sys.exit(1)
434
+ except KeyboardInterrupt:
435
+ logger.info("Flow execution interrupted by user")
436
+ click.echo("Flow execution interrupted by user", err=True)
437
+ sys.exit(130) # Standard exit code for SIGINT
438
+ except Exception as e:
439
+ logger.error(f"Unexpected error: {e}")
440
+ logger.error(f"Traceback: {traceback.format_exc()}")
441
+ click.echo(
442
+ f"Unexpected error occurred. Please check the logs for details. Error: {e}",
443
+ err=True,
444
+ )
445
+ sys.exit(1)
212
446
 
213
447
 
214
448
  if __name__ == "__main__":
@@ -2,7 +2,7 @@
2
2
  block_config:
3
3
  block_name: gen_mmlu_knowledge
4
4
  config_path: configs/knowledge/mcq_generation.yaml
5
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
5
+ model_id: meta-llama/Llama-3.3-70B-Instruct
6
6
  output_cols:
7
7
  - mmlubench_question
8
8
  - mmlubench_answer
@@ -2,7 +2,7 @@
2
2
  block_config:
3
3
  block_name: gen_knowledge
4
4
  config_path: configs/knowledge/simple_generate_qa.yaml
5
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
5
+ model_id: meta-llama/Llama-3.3-70B-Instruct
6
6
  output_cols:
7
7
  - output
8
8
  gen_kwargs:
@@ -2,7 +2,7 @@
2
2
  block_config:
3
3
  block_name: gen_knowledge
4
4
  config_path: configs/knowledge/generate_questions_responses.yaml
5
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
5
+ model_id: meta-llama/Llama-3.3-70B-Instruct
6
6
  output_cols:
7
7
  - question
8
8
  - response
@@ -20,7 +20,7 @@
20
20
  block_config:
21
21
  block_name: eval_faithfulness_qa_pair
22
22
  config_path: configs/knowledge/evaluate_faithfulness.yaml
23
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
23
+ model_id: meta-llama/Llama-3.3-70B-Instruct
24
24
  output_cols:
25
25
  - explanation
26
26
  - judgment
@@ -43,7 +43,7 @@
43
43
  block_config:
44
44
  block_name: eval_relevancy_qa_pair
45
45
  config_path: configs/knowledge/evaluate_relevancy.yaml
46
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
46
+ model_id: meta-llama/Llama-3.3-70B-Instruct
47
47
  output_cols:
48
48
  - feedback
49
49
  - score
@@ -67,7 +67,7 @@
67
67
  block_config:
68
68
  block_name: eval_verify_question
69
69
  config_path: configs/knowledge/evaluate_question.yaml
70
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
70
+ model_id: meta-llama/Llama-3.3-70B-Instruct
71
71
  output_cols:
72
72
  - explanation
73
73
  - rating
@@ -8,38 +8,35 @@
8
8
  block_config:
9
9
  block_name: gen_detailed_summary
10
10
  config_path: configs/knowledge/detailed_summary.yaml
11
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
11
+ model_id: meta-llama/Llama-3.3-70B-Instruct
12
12
  output_cols:
13
13
  - summary_detailed
14
14
  gen_kwargs:
15
15
  max_tokens: 4096
16
16
  temperature: 0.7
17
- seed: 7452
18
17
  n: 50
19
18
 
20
19
  - block_type: LLMBlock
21
20
  block_config:
22
21
  block_name: gen_atomic_facts
23
22
  config_path: configs/knowledge/atomic_facts.yaml
24
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
23
+ model_id: meta-llama/Llama-3.3-70B-Instruct
25
24
  output_cols:
26
25
  - summary_atomic_facts
27
26
  gen_kwargs:
28
27
  max_tokens: 4096
29
28
  temperature: 0.7
30
- seed: 7452
31
29
 
32
30
  - block_type: LLMBlock
33
31
  block_config:
34
32
  block_name: gen_extractive_summary
35
33
  config_path: configs/knowledge/extractive_summary.yaml
36
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
34
+ model_id: meta-llama/Llama-3.3-70B-Instruct
37
35
  output_cols:
38
36
  - summary_extractive
39
37
  gen_kwargs:
40
38
  max_tokens: 4096
41
39
  temperature: 0.7
42
- seed: 7452
43
40
 
44
41
  - block_type: FlattenColumnsBlock
45
42
  block_config:
@@ -63,7 +60,7 @@
63
60
  block_config:
64
61
  block_name: knowledge generation
65
62
  config_path: configs/knowledge/generate_questions.yaml
66
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
63
+ model_id: meta-llama/Llama-3.3-70B-Instruct
67
64
  output_cols:
68
65
  - question
69
66
  parser_kwargs:
@@ -72,25 +69,23 @@
72
69
  gen_kwargs:
73
70
  temperature: 0.7
74
71
  max_tokens: 100
75
- seed: 7452
76
72
 
77
73
  - block_type: LLMBlock
78
74
  block_config:
79
75
  block_name: knowledge generation
80
76
  config_path: configs/knowledge/generate_responses.yaml
81
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
77
+ model_id: meta-llama/Llama-3.3-70B-Instruct
82
78
  output_cols:
83
79
  - response
84
80
  gen_kwargs:
85
81
  temperature: 0.7
86
82
  max_tokens: 2048
87
- seed: 7452
88
83
 
89
84
  - block_type: LLMBlock
90
85
  block_config:
91
86
  block_name: eval_faithfulness_qa_pair
92
87
  config_path: configs/knowledge/evaluate_faithfulness.yaml
93
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
88
+ model_id: meta-llama/Llama-3.3-70B-Instruct
94
89
  output_cols:
95
90
  - explanation
96
91
  - judgment
@@ -111,7 +106,7 @@
111
106
  block_config:
112
107
  block_name: eval_relevancy_qa_pair
113
108
  config_path: configs/knowledge/evaluate_relevancy.yaml
114
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
109
+ model_id: meta-llama/Llama-3.3-70B-Instruct
115
110
  output_cols:
116
111
  - feedback
117
112
  - score
@@ -133,7 +128,7 @@
133
128
  block_config:
134
129
  block_name: eval_verify_question
135
130
  config_path: configs/knowledge/evaluate_question.yaml
136
- model_id: mistralai/Mixtral-8x7B-Instruct-v0.1
131
+ model_id: meta-llama/Llama-3.3-70B-Instruct
137
132
  output_cols:
138
133
  - explanation
139
134
  - rating
sdg_hub/prompts.py CHANGED
@@ -28,6 +28,11 @@ def microsoft_phi_chat_template():
28
28
 
29
29
  @PromptRegistry.register("nvidia/Llama-3_3-Nemotron-Super-49B-v1")
30
30
  def nemotron_chat_template():
31
+ """
32
+ Format chat messages for the Nemotron model, including a system prompt and structured message headers.
33
+
34
+ The template starts with a system message containing "detailed thinking on", then iterates over messages, wrapping each with start and end header tokens and an end-of-text token. For assistant messages containing a `</think>` tag, only the content after this tag is included. Optionally appends an assistant prompt if generation is requested.
35
+ """
31
36
  return """{{- bos_token }}
32
37
  {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}detailed thinking on{{- "<|eot_id|>" }}
33
38
  {%- for message in messages %}
@@ -41,3 +46,29 @@ def nemotron_chat_template():
41
46
  {%- if add_generation_prompt %}
42
47
  {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
43
48
  {%- endif %}"""
49
+
50
+
51
+ @PromptRegistry.register("Qwen/Qwen2.5")
52
+ def qwen_2_5_chat_template():
53
+ """
54
+ Formats chat messages into the prompt structure required by the Qwen 2.5 model family, supporting system messages, tool descriptions, function call instructions, and role-based message formatting.
55
+
56
+ If tools are provided, includes tool signatures and instructions for function calls in the system prompt. User, assistant, and tool messages are wrapped with special tokens, and assistant tool calls are serialized as JSON within XML tags. Optionally appends a generation prompt for the assistant.
57
+ """
58
+ return """{%- if tools %}\n {{- \'<|im_start|>system\\n\' }}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- messages[0][\'content\'] }}\n {%- else %}\n {{- \'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\' }}\n {%- endif %}\n {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n {%- for tool in tools %}\n {{- "\\n" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}\n {%- else %}\n {{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}\n {%- elif message.role == "assistant" %}\n {{- \'<|im_start|>\' + message.role }}\n {%- if message.content %}\n {{- \'\\n\' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- \'\\n<tool_call>\\n{"name": "\' }}\n {{- tool_call.name }}\n {{- \'", "arguments": \' }}\n {{- tool_call.arguments | tojson }}\n {{- \'}\\n</tool_call>\' }}\n {%- endfor %}\n {{- \'<|im_end|>\\n\' }}\n {%- elif message.role == "tool" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}\n {{- \'<|im_start|>user\' }}\n {%- endif %}\n {{- \'\\n<tool_response>\\n\' }}\n {{- message.content }}\n {{- \'\\n</tool_response>\' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}\n {{- \'<|im_end|>\\n\' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|im_start|>assistant\\n\' }}\n{%- endif %}\n"""
59
+
60
+
61
+ @PromptRegistry.register("Qwen/Qwen3")
62
+ def qwen_3_chat_template():
63
+ """
64
+ Formats chat messages for the Qwen 3 model family, supporting multi-step tool usage, reasoning content, and special XML tags for tool calls and responses.
65
+
66
+ This template handles system messages, user and assistant roles, and tool interactions. When tools are provided, it outputs their signatures and instructions for function calls. It tracks the last user query to determine where to insert assistant reasoning content within `<think>` tags. Assistant tool calls are serialized as JSON within `<tool_call>` tags, and tool responses are grouped inside `<tool_response>` tags. Optionally, a generation prompt and empty reasoning block can be added.
67
+
68
+ Parameters:
69
+ tools (optional): List of tool signature objects to be included in the prompt.
70
+ messages: List of message objects, each with a role and content, and optionally tool_calls or reasoning_content.
71
+ add_generation_prompt (optional): If true, appends an assistant prompt for generation.
72
+ enable_thinking (optional): If false, inserts an empty reasoning block in the assistant prompt.
73
+ """
74
+ return """{%- if tools %}\n {{- \'<|im_start|>system\\n\' }}\n {%- if messages[0].role == \'system\' %}\n {{- messages[0].content + \'\\n\\n\' }}\n {%- endif %}\n {{- "# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n {%- for tool in tools %}\n {{- "\\n" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n {%- if messages[0].role == \'system\' %}\n {{- \'<|im_start|>system\\n\' + messages[0].content + \'<|im_end|>\\n\' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith(\'<tool_response>\') and message.content.endswith(\'</tool_response>\')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = \'\' %}\n {%- endif %}\n {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + content + \'<|im_end|>\' + \'\\n\' }}\n {%- elif message.role == "assistant" %}\n {%- set reasoning_content = \'\' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if \'</think>\' in content %}\n {%- set reasoning_content = content.split(\'</think>\')[0].rstrip(\'\\n\').split(\'<think>\')[-1].lstrip(\'\\n\') %}\n {%- set content = content.split(\'</think>\')[-1].lstrip(\'\\n\') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- \'<|im_start|>\' + message.role + \'\\n<think>\\n\' + reasoning_content.strip(\'\\n\') + \'\\n</think>\\n\\n\' + content.lstrip(\'\\n\') }}\n {%- else %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + content }}\n {%- endif %}\n {%- else %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- \'\\n\' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- \'<tool_call>\\n{"name": "\' }}\n {{- tool_call.name }}\n {{- \'", "arguments": \' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- \'}\\n</tool_call>\' }}\n {%- endfor %}\n {%- endif %}\n {{- \'<|im_end|>\\n\' }}\n {%- elif message.role == "tool" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}\n {{- \'<|im_start|>user\' }}\n {%- endif %}\n {{- \'\\n<tool_response>\\n\' }}\n {{- content }}\n {{- \'\\n</tool_response>\' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}\n {{- \'<|im_end|>\\n\' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|im_start|>assistant\\n\' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- \'<think>\\n\\n</think>\\n\\n\' }}\n {%- endif %}\n{%- endif %}"""
sdg_hub/utils/__init__.py CHANGED
@@ -3,3 +3,8 @@
3
3
  # This is part of the public API, and used by instructlab
4
4
  class GenerateException(Exception):
5
5
  """An exception raised during generate step."""
6
+
7
+
8
+ from .path_resolution import resolve_path
9
+
10
+ __all__ = ["GenerateException", "resolve_path"]