DeepFabric 4.4.0__py3-none-any.whl → 4.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
  import json
  import logging
+ import sys

  from typing import Any

@@ -36,9 +37,15 @@ class TransformersBackend(InferenceBackend):
  """
  super().__init__(config)

+ # Check if model is pre-loaded (not a string path)
+ is_preloaded = not isinstance(config.model, str)
+
  # Determine device
  if config.device:
  self.device = config.device
+ elif is_preloaded:
+ # Get device from pre-loaded model
+ self.device = str(next(config.model.parameters()).device)
  # Auto-detect best available device
  elif torch.cuda.is_available():
  self.device = "cuda"
@@ -48,7 +55,7 @@ class TransformersBackend(InferenceBackend):
  self.device = "cpu"

  # Determine dtype based on device
- if self.device == "cuda":
+ if self.device == "cuda" or self.device.startswith("cuda:"):
  dtype = torch.float16
  device_map = "auto"
  elif self.device == "mps":
@@ -58,11 +65,36 @@ class TransformersBackend(InferenceBackend):
  dtype = torch.float32
  device_map = None

+ # Handle pre-loaded model case - skip all loading logic
+ if is_preloaded:
+ self.model = config.model
+ self.tokenizer = config.tokenizer
+ self.loaded_with_unsloth = False
+
+ # Detect architecture from pre-loaded model's config
+ self._architectures = []
+ if hasattr(self.model, "config"):
+ self._architectures = getattr(self.model.config, "architectures", []) or []
+
+ # Initialize tool call parser
+ self._tool_call_parser: ToolCallParser = get_parser(self._architectures)
+ logger.info(
+ "Using pre-loaded model with %s parser for architectures: %s",
+ type(self._tool_call_parser).__name__,
+ self._architectures or ["unknown"],
+ )
+
+ # Set padding token if not set
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ return # Skip remaining initialization
+
  # Detect model architecture for parser selection and tokenizer config
- self._architectures: list[str] = []
+ self._architectures = []
  tokenizer_kwargs: dict[str, Any] = {}
  try:
- model_config = AutoConfig.from_pretrained(config.model_path) # nosec
+ model_config = AutoConfig.from_pretrained(config.model) # nosec
  self._architectures = getattr(model_config, "architectures", []) or []
  if any(arch in MISTRAL_ARCHITECTURES for arch in self._architectures):
  tokenizer_kwargs["fix_mistral_regex"] = True
@@ -71,7 +103,7 @@ class TransformersBackend(InferenceBackend):
  logger.warning("Could not detect model architecture: %s", e)

  # Initialize tool call parser based on detected architecture
- self._tool_call_parser: ToolCallParser = get_parser(self._architectures)
+ self._tool_call_parser = get_parser(self._architectures)
  logger.info(
  "Using %s for model architectures: %s",
  type(self._tool_call_parser).__name__,
@@ -79,19 +111,44 @@ class TransformersBackend(InferenceBackend):
  )

  self.loaded_with_unsloth = False
- # Load with Unsloth if requested
- if config.use_unsloth:
+
+ # Detect if Unsloth has already patched the environment
+ # This happens when user imports unsloth in the same runtime
+ unsloth_patched = "unsloth" in sys.modules
+
+ # Use Unsloth if explicitly requested OR if Unsloth has patched the environment
+ # (to avoid "apply_qkv" errors from patched attention classes)
+ use_unsloth_loading = config.use_unsloth or unsloth_patched
+
+ if use_unsloth_loading:
  try:
  from unsloth import FastLanguageModel # type: ignore # noqa: PLC0415

- # Load from adapter path if provided, otherwise from model_path
- load_path = config.adapter_path if config.adapter_path else config.model_path
- self.model, self.tokenizer = FastLanguageModel.from_pretrained(
- model_name=load_path,
- max_seq_length=config.max_seq_length,
- dtype=dtype,
- load_in_4bit=config.load_in_4bit,
- )
+ if unsloth_patched and not config.use_unsloth:
+ logger.info(
+ "Unsloth detected in environment, using Unsloth loader for compatibility"
+ )
+
+ if config.adapter_path:
+ # Load base model first, then apply adapter
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+ model_name=config.model,
+ max_seq_length=config.max_seq_length,
+ dtype=dtype,
+ load_in_4bit=config.load_in_4bit,
+ )
+ # Load LoRA adapter using PEFT
+ from peft import PeftModel # noqa: PLC0415
+
+ self.model = PeftModel.from_pretrained(self.model, config.adapter_path)
+ else:
+ # Load merged model or base model directly
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+ model_name=config.model,
+ max_seq_length=config.max_seq_length,
+ dtype=dtype,
+ load_in_4bit=config.load_in_4bit,
+ )
  FastLanguageModel.for_inference(self.model)
  self.loaded_with_unsloth = True
  except ImportError:
@@ -104,11 +161,11 @@ class TransformersBackend(InferenceBackend):
  # Standard transformers/PEFT loading
  if not self.loaded_with_unsloth:
  self.tokenizer = AutoTokenizer.from_pretrained( # nosec
- config.model_path, **tokenizer_kwargs
+ config.model, **tokenizer_kwargs
  )

  self.model = AutoModelForCausalLM.from_pretrained( # nosec
- config.model_path,
+ config.model,
  device_map=device_map,
  dtype=dtype,
  )
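
For a pre-loaded model, `str(next(config.model.parameters()).device)` yields an indexed device string such as `cuda:0` rather than plain `cuda`, which is why the dtype check above now also matches the `cuda:` prefix. A minimal sketch of that logic in isolation (the helper names `detect_device` and `pick_dtype` are illustrative, not part of the package):

    import torch

    def detect_device(model: torch.nn.Module) -> str:
        # Device of the first parameter, e.g. "cpu", "mps", or "cuda:0"
        return str(next(model.parameters()).device)

    def pick_dtype(device: str) -> torch.dtype:
        # float16 for any CUDA device, including indexed ones like "cuda:0"
        if device == "cuda" or device.startswith("cuda:"):
            return torch.float16
        return torch.float32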
@@ -36,12 +36,12 @@ class EvaluatorConfig(BaseModel):
  default=None,
  description="Path to save evaluation results",
  )
- model_path: str | None = Field(
+ model: str | None = Field(
  default=None,
- description="Path to model to evaluate (overrides inference_config.model_path)",
+ description="Model to evaluate (overrides inference_config.model)",
  )
  inference_config: InferenceConfig = Field(
- description="Inference backend configuration (includes model_path)",
+ description="Inference backend configuration (includes model)",
  )
  batch_size: int = Field(
  default=1,
@@ -119,7 +119,7 @@ class Evaluator:
  "evaluator_created",
  {
  "backend": self.config.inference_config.backend,
- "model_path": self.config.inference_config.model_path,
+ "model": self.config.inference_config.model,
  "has_adapter": self.config.inference_config.adapter_path is not None,
  "evaluators": (
  list(self.config.evaluators)
@@ -434,6 +434,7 @@ class Evaluator:
  ground_truth=ground_truth,
  response=response,
  evaluator_results=evaluator_results,
+ tools=tools,
  )

  except Exception as e: # noqa: BLE001
@@ -442,8 +443,9 @@ class Evaluator:
  expected_tool = None
  expected_params: dict[str, Any] = {}
  expected_answer = None
+ available_tool_names: list[str] = []

- # Try to extract ground truth if available
+ # Try to extract ground truth and tools if available
  try:
  gt = self.extract_ground_truth(sample)
  query = gt.query
@@ -453,9 +455,16 @@ class Evaluator:
  except (KeyError, AttributeError, ValidationError):
  pass

+ try:
+ tools = self.prepare_tools(sample)
+ available_tool_names = [t.name for t in tools]
+ except (KeyError, AttributeError, ValidationError):
+ pass
+
  return SampleEvaluation(
  sample_id=sample_id,
  query=query,
+ available_tools=available_tool_names,
  expected_tool=expected_tool,
  predicted_tool=None,
  expected_parameters=expected_params,
@@ -560,6 +569,7 @@ class Evaluator:
  ground_truth=ground_truth,
  predicted_tool_calls=all_predicted_tool_calls,
  final_content=final_content,
+ tools=tools,
  )

  except Exception as e: # noqa: BLE001
@@ -568,6 +578,7 @@ class Evaluator:
  expected_tool = None
  expected_params: dict[str, Any] = {}
  expected_answer = None
+ available_tool_names: list[str] = []

  try:
  gt = self.extract_ground_truth(sample)
@@ -578,9 +589,16 @@ class Evaluator:
  except (KeyError, AttributeError, ValidationError):
  pass

+ try:
+ tools = self.prepare_tools(sample)
+ available_tool_names = [t.name for t in tools]
+ except (KeyError, AttributeError, ValidationError):
+ pass
+
  return SampleEvaluation(
  sample_id=sample_id,
  query=query,
+ available_tools=available_tool_names,
  expected_tool=expected_tool,
  predicted_tool=None,
  expected_parameters=expected_params,
@@ -600,6 +618,7 @@ class Evaluator:
  ground_truth: GroundTruth,
  predicted_tool_calls: list[dict],
  final_content: str,
+ tools: list[ToolDefinition] | None = None,
  ) -> SampleEvaluation:
  """Compute metrics for multi-turn evaluation.

@@ -610,6 +629,7 @@ class Evaluator:
  ground_truth: Expected values including all expected tools
  predicted_tool_calls: All tool calls made by model across turns
  final_content: Final model response content
+ tools: List of available tools for this sample

  Returns:
  SampleEvaluation with computed metrics
@@ -652,9 +672,13 @@ class Evaluator:
  # Execution valid if we got through the conversation
  execution_valid = len(predicted_tool_calls) > 0 or final_content != ""

+ # Extract tool names for available_tools field
+ available_tool_names = [t.name for t in tools] if tools else []
+
  return SampleEvaluation(
  sample_id=sample_id,
  query=ground_truth.query,
+ available_tools=available_tool_names,
  expected_tool=ground_truth.expected_tool,
  predicted_tool=first_predicted_tool,
  expected_parameters=ground_truth.expected_parameters,
@@ -714,6 +738,7 @@ class Evaluator:
  ground_truth: GroundTruth,
  response: ModelResponse,
  evaluator_results: list[EvaluatorResult],
+ tools: list[ToolDefinition] | None = None,
  ) -> SampleEvaluation:
  """Aggregate evaluator results into SampleEvaluation.

@@ -722,6 +747,7 @@ class Evaluator:
  ground_truth: Expected values
  response: Model response
  evaluator_results: Results from all evaluators
+ tools: List of available tools for this sample

  Returns:
  SampleEvaluation with aggregated metrics
@@ -746,10 +772,14 @@ class Evaluator:
  params_correct = metrics.get("parameter_accuracy", 0.0) == 1.0
  execution_valid = metrics.get("execution_valid", 0.0) == 1.0

+ # Extract tool names for available_tools field
+ available_tool_names = [t.name for t in tools] if tools else []
+
  # Return backwards-compatible SampleEvaluation
  return SampleEvaluation(
  sample_id=sample_id,
  query=ground_truth.query,
+ available_tools=available_tool_names,
  expected_tool=ground_truth.expected_tool,
  predicted_tool=predicted_tool,
  expected_parameters=ground_truth.expected_parameters,
@@ -780,13 +810,17 @@ class Evaluator:
  console.print("[bold blue]Running evaluation...[/bold blue]")
  evaluations = []

- for idx, sample in tqdm(enumerate(samples), total=len(samples), desc="Evaluating"):
+ pbar = tqdm(enumerate(samples), total=len(samples), desc="Evaluating")
+ for idx, sample in pbar:
  eval_result = self.evaluate_sample(sample, idx)
  evaluations.append(eval_result)

  # Stream sample to reporters (for cloud real-time tracking)
  self.reporter.report_sample(eval_result)

+ # Force refresh for notebook compatibility
+ pbar.refresh()
+
  console.print("[bold green]Evaluation complete![/bold green]")

  # Compute metrics
@@ -804,7 +838,7 @@ class Evaluator:
  "evaluation_completed",
  {
  "backend": self.config.inference_config.backend,
- "model_path": self.config.inference_config.model_path,
+ "model": self.config.inference_config.model,
  "has_adapter": self.config.inference_config.adapter_path is not None,
  "samples_evaluated": metrics.samples_evaluated,
  "samples_processed": metrics.samples_processed,
@@ -63,14 +63,19 @@ class ToolCallingEvaluator(BaseEvaluator):
  # Compute metrics
  tool_correct = predicted_tool == ground_truth.expected_tool

- # Validate parameters against the PREDICTED tool (not expected)
- # This measures parameter extraction capability independently of tool selection
- params_correct = compare_parameters(
- ground_truth.expected_parameters,
- predicted_params,
- tool_name=predicted_tool, # Use predicted tool for schema validation
- tool_definitions=context.tools,
- )
+ # Parameter accuracy requires a tool to have been called
+ # If no tool was predicted but one was expected, params cannot be correct
+ if predicted_tool is None and ground_truth.expected_tool is not None:
+ params_correct = False
+ else:
+ # Validate parameters against the PREDICTED tool (not expected)
+ # This measures parameter extraction capability independently of tool selection
+ params_correct = compare_parameters(
+ ground_truth.expected_parameters,
+ predicted_params,
+ tool_name=predicted_tool, # Use predicted tool for schema validation
+ tool_definitions=context.tools,
+ )

  # Execution valid requires BOTH correct tool AND correct params
  execution_valid = tool_correct and params_correct
@@ -1,9 +1,9 @@
  """Model inference interfaces and implementations for evaluation."""

  from abc import ABC, abstractmethod
- from typing import Literal
+ from typing import Any, Literal

- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator

  from ..schemas import ToolDefinition

@@ -11,17 +11,40 @@ from ..schemas import ToolDefinition
  class InferenceConfig(BaseModel):
  """Configuration for model inference."""

- model_path: str = Field(
- description="Path to model (local path or HuggingFace Hub ID)",
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ model: str | Any = Field(
+ description="Model identifier (local path, HuggingFace Hub ID, or model name for cloud providers). "
+ "Can also be a pre-loaded model object to avoid reloading.",
+ )
+ tokenizer: Any | None = Field(
+ default=None,
+ description="Pre-loaded tokenizer object. Required when model is a pre-loaded model object.",
  )
  adapter_path: str | None = Field(
  default=None,
  description="Path to PEFT/LoRA adapter (if using adapter-based fine-tuning)",
  )
- backend: Literal["transformers", "ollama"] = Field(
+ backend: Literal["transformers", "ollama", "llm"] = Field(
  default="transformers",
  description="Inference backend to use",
  )
+ provider: Literal["openai", "anthropic", "gemini", "openrouter"] | None = Field(
+ default=None,
+ description="Cloud LLM provider (required when backend='llm')",
+ )
+ api_key: str | None = Field(
+ default=None,
+ description="API key for the provider (falls back to environment variable if not set)",
+ )
+ base_url: str | None = Field(
+ default=None,
+ description="Custom base URL for the API (e.g., for OpenRouter or proxies)",
+ )
+ rate_limit_config: dict | None = Field(
+ default=None,
+ description="Rate limiting configuration overrides",
+ )
  use_unsloth: bool = Field(
  default=False,
  description="Use Unsloth for loading adapter (for adapters trained with Unsloth)",
@@ -62,6 +85,51 @@ class InferenceConfig(BaseModel):
  description="Batch size for inference",
  )

+ @field_serializer("model")
+ def serialize_model(self, value: str | Any) -> str:
+ """Serialize model field - convert objects to descriptive string."""
+ if isinstance(value, str):
+ return value
+ # For in-memory model objects, return a descriptive string
+ model_class = type(value).__name__
+ model_name = getattr(getattr(value, "config", None), "name_or_path", "unknown")
+ return f"<in-memory:{model_class}:{model_name}>"
+
+ @field_serializer("tokenizer")
+ def serialize_tokenizer(self, value: Any | None) -> str | None:
+ """Serialize tokenizer field - convert objects to descriptive string."""
+ if value is None:
+ return None
+ if isinstance(value, str):
+ return value
+ # For in-memory tokenizer objects, return a descriptive string
+ tokenizer_class = type(value).__name__
+ tokenizer_name = getattr(value, "name_or_path", "unknown")
+ return f"<in-memory:{tokenizer_class}:{tokenizer_name}>"
+
+ @model_validator(mode="after")
+ def validate_config(self) -> "InferenceConfig":
+ """Validate configuration consistency."""
+ # Ensure provider is set when using LLM backend
+ if self.backend == "llm" and self.provider is None:
+ msg = "provider must be specified when backend='llm'"
+ raise ValueError(msg)
+
+ # Check if model is a pre-loaded object (not a string path)
+ is_preloaded_model = not isinstance(self.model, str)
+
+ # If model is pre-loaded, tokenizer must also be provided
+ if is_preloaded_model and self.tokenizer is None:
+ msg = "tokenizer must be provided when using a pre-loaded model object"
+ raise ValueError(msg)
+
+ # Pre-loaded models only work with transformers backend
+ if is_preloaded_model and self.backend != "transformers":
+ msg = "pre-loaded model objects are only supported with backend='transformers'"
+ raise ValueError(msg)
+
+ return self
+

  class ModelResponse(BaseModel):
  """Model inference response."""
@@ -150,6 +218,10 @@ def create_inference_backend(config: InferenceConfig) -> InferenceBackend:
  from .backends.ollama_backend import OllamaBackend # noqa: PLC0415

  return OllamaBackend(config)
+ if config.backend == "llm":
+ from .backends.llm_eval_backend import LLMEvalBackend # noqa: PLC0415
+
+ return LLMEvalBackend(config)

  msg = f"Unsupported backend: {config.backend}"
  raise ValueError(msg)
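
Taken together, the serializers and validator above mean a pre-loaded model must be paired with its tokenizer and the `transformers` backend, while `backend="llm"` requires a provider. A minimal sketch, assuming `InferenceConfig` is imported from this module; the model ids are illustrative:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # from deepfabric... import InferenceConfig  # adjust to the actual module path

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative; any causal LM works
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Passing objects avoids a second load inside the backend; validate_config
    # insists on tokenizer being set and backend="transformers" in this case.
    local_config = InferenceConfig(model=model, tokenizer=tokenizer, backend="transformers")

    # A string model with backend="llm" must name a provider, per validate_config.
    cloud_config = InferenceConfig(model="gpt-4o-mini", backend="llm", provider="openai")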
@@ -107,6 +107,10 @@ class SampleEvaluation(BaseModel):

  sample_id: int = Field(description="Sample index")
  query: str = Field(description="Input query")
+ available_tools: list[str] = Field(
+ default_factory=list,
+ description="List of tool names available for this sample",
+ )
  expected_tool: str | None = Field(
  default=None,
  description="Expected tool name",
@@ -103,7 +103,7 @@ class CloudReporter(BaseReporter):
  run_data = {
  "project_id": self.project_id,
  "name": f"Evaluation - {datetime.now(UTC).strftime('%Y-%m-%d %H:%M')}",
- "model_name": result.config.inference_config.model_path,
+ "model_name": result.config.inference_config.model,
  "model_provider": result.config.inference_config.backend,
  "config": {
  "evaluators": getattr(result.config, "evaluators", ["tool_calling"]),
deepfabric/generator.py CHANGED
@@ -213,6 +213,10 @@ class DataSetGeneratorConfig(BaseModel):
  le=20,
  description="Minimum number of tool calls required before allowing early conversation conclusion",
  )
+ tool_inclusion_strategy: Literal["all", "used_only"] = Field(
+ default="used_only",
+ description="Which tools to include in each sample: 'all' includes full catalog, 'used_only' includes only tools actually called (recommended for training)",
+ )


  class DataSetGenerator:
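
The new `tool_inclusion_strategy` field decides whether each generated sample carries the full tool catalog or only the tools that were actually called. A small illustrative helper (not from the package) showing the difference between the two strategies:

    def select_tools(catalog: list[dict], called: set[str], strategy: str) -> list[dict]:
        # 'all' keeps the full catalog; 'used_only' keeps only tools that were called
        if strategy == "all":
            return catalog
        return [tool for tool in catalog if tool["name"] in called]

    catalog = [{"name": "get_weather"}, {"name": "search_web"}, {"name": "send_email"}]
    print(select_tools(catalog, {"get_weather"}, "used_only"))  # [{'name': 'get_weather'}]

With `used_only` (the default), samples stay shorter because unused tool schemas are not serialized into every prompt.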
deepfabric/hf_hub.py CHANGED
@@ -210,5 +210,5 @@ class HFUploader:
  else:
  return {
  "status": "success",
- "message": f"Dataset pushed successfully to {hf_dataset_repo}.",
+ "message": f"Dataset pushed successfully to https://huggingface.co/datasets/{hf_dataset_repo}",
  }
deepfabric/llm/client.py CHANGED
@@ -1,5 +1,7 @@
  import asyncio
+ import logging
  import os
+ import sys

  from functools import lru_cache
  from typing import Any
@@ -21,6 +23,8 @@ from .rate_limit_config import (
  )
  from .retry_handler import RetryHandler, retry_with_backoff, retry_with_backoff_async

+ logger = logging.getLogger(__name__)
+
  # JSON Schema union type keys that need recursive processing
  _UNION_KEYS = ("anyOf", "oneOf", "allOf")

@@ -1061,8 +1065,18 @@ def _get_cached_openai_schema(schema: type[BaseModel]) -> type[BaseModel]:
  OpenAICompatModel.__name__ = f"{schema.__name__}OpenAICompat"
  OpenAICompatModel.__doc__ = schema.__doc__

- # Rebuild model to resolve forward references (e.g., PendingToolCall in AgentStep)
- OpenAICompatModel.model_rebuild()
+ # Rebuild model with the schema's original module namespace to resolve
+ # forward references (e.g., PendingToolCall in AgentStep)
+ schema_module = sys.modules.get(schema.__module__)
+ if schema_module:
+ OpenAICompatModel.model_rebuild(_types_namespace=vars(schema_module))
+ else:
+ logger.warning(
+ "Could not find module '%s' in sys.modules. "
+ "Forward reference resolution for dynamically created models may fail.",
+ schema.__module__,
+ )
+ OpenAICompatModel.model_rebuild()

  return OpenAICompatModel

deepfabric/schemas.py CHANGED
@@ -304,8 +304,8 @@ class ToolDefinition(BaseModel):
  description = func.get("description", "")
  params_schema = func.get("parameters", {})

- properties = params_schema.get("properties", {})
- required_params = set(params_schema.get("required", []))
+ properties = params_schema.get("properties") or {}
+ required_params = set(params_schema.get("required") or [])

  parameters = []
  for param_name, param_props in properties.items():
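
The switch from `get(key, default)` to `get(key) or default` matters when a tool schema carries an explicit `"properties": null`: `dict.get` only returns the default when the key is missing, not when its value is `None`. A quick illustration:

    params_schema = {"properties": None, "required": None}  # keys present, values null

    old_style = params_schema.get("properties", {})       # -> None; .items() would raise
    new_style = params_schema.get("properties") or {}     # -> {}
    required = set(params_schema.get("required") or [])   # -> set()

    print(old_style, new_style, required)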
@@ -1,20 +1,27 @@
- """DeepFabric training metrics logging.
+ """DeepFabric training utilities.

- This module provides integration with HuggingFace Trainer and TRL trainers
- to log training metrics to the DeepFabric SaaS backend.
+ This module provides:
+ - Integration with HuggingFace Trainer and TRL trainers for metrics logging
+ - Dataset preparation utilities for optimizing training data

  Features:
  - Non-blocking async metrics sending
  - Notebook-friendly API key prompts (like wandb)
  - Graceful handling of failures without impacting training
+ - Tool filtering to reduce sequence lengths and memory usage

  Usage:
- from deepfabric.training import DeepFabricCallback
+ from deepfabric.training import DeepFabricCallback, prepare_dataset_for_training

+ # Prepare dataset (reduces tool overhead)
+ dataset = load_dataset("your/dataset", split="train")
+ prepared = prepare_dataset_for_training(dataset, tool_strategy="used_only")
+
+ # Train with metrics logging
  trainer = Trainer(
  model=model,
  args=training_args,
- train_dataset=train_dataset,
+ train_dataset=prepared,
  )
  trainer.add_callback(DeepFabricCallback(trainer))
  trainer.train()
@@ -27,9 +34,21 @@ Environment Variables:
  from __future__ import annotations

  from .callback import DeepFabricCallback
+ from .dataset_utils import (
+ ToolInclusionStrategy,
+ clean_tool_schema,
+ filter_tools_for_sample,
+ get_used_tool_names,
+ prepare_dataset_for_training,
+ )
  from .metrics_sender import MetricsSender

  __all__ = [
  "DeepFabricCallback",
  "MetricsSender",
+ "ToolInclusionStrategy",
+ "clean_tool_schema",
+ "filter_tools_for_sample",
+ "get_used_tool_names",
+ "prepare_dataset_for_training",
  ]
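
Mirroring the usage shown in the module docstring above, a short sketch of the new export (the dataset id is a placeholder, and the `tool_strategy` keyword is taken from the docstring example):

    from datasets import load_dataset

    from deepfabric.training import DeepFabricCallback, prepare_dataset_for_training

    # Drop unused tool schemas from each sample before tokenization
    dataset = load_dataset("your/dataset", split="train")
    prepared = prepare_dataset_for_training(dataset, tool_strategy="used_only")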