sdg-hub 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdg_hub/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.0a2'
+ __version__ = version = '0.1.0a3'
  __version_tuple__ = version_tuple = (0, 1, 0)
sdg_hub/flow.py CHANGED
@@ -38,10 +38,37 @@ class Flow(ABC):
          self.base_path = str(resources.files(__package__))
          self.registered_blocks = BlockRegistry.get_registry()

+     def _getFilePath(self, dirs, filename):
+         """
+         Find a named configuration file.
+
+         Files are checked in the following order:
+         - an absolute path is always used as-is
+         - otherwise the path is checked relative to each directory in "dirs"
+         - finally the path is treated as relative to the current directory
+
+         Args:
+             dirs (list): Directories in which to search for "filename".
+             filename (str): The path to the configuration file.
+
+         Returns:
+             Selected file path
+         """
+         if os.path.isabs(filename):
+             return filename
+         for d in dirs:
+             full_file_path = os.path.join(d, filename)
+             if os.path.isfile(full_file_path):
+                 return full_file_path
+         # If not found above, return the path unchanged, i.e.
+         # assume the path is relative to the current directory
+         return filename
+
      def get_flow_from_file(self, yaml_path: str) -> list:
          yaml_path_relative_to_base = os.path.join(self.base_path, yaml_path)
          if os.path.isfile(yaml_path_relative_to_base):
              yaml_path = yaml_path_relative_to_base
+         yaml_dir = os.path.dirname(yaml_path)

          try:
              with open(yaml_path, "r", encoding="utf-8") as yaml_file:
@@ -86,33 +113,23 @@ class Flow(ABC):

              # update config path to absolute path
              if "config_path" in block["block_config"]:
-                 config_path_relative_to_base = os.path.join(
-                     self.base_path, block["block_config"]["config_path"]
+                 block["block_config"]["config_path"] = self._getFilePath(
+                     [yaml_dir, self.base_path], block["block_config"]["config_path"]
                  )
-                 if os.path.isfile(config_path_relative_to_base):
-                     block["block_config"]["config_path"] = config_path_relative_to_base

              # update config paths to absolute paths - this might be a list or a dict
              if "config_paths" in block["block_config"]:
                  if isinstance(block["block_config"]["config_paths"], dict):
                      for key, path in block["block_config"]["config_paths"].items():
-                         config_path_relative_to_base = os.path.join(
-                             self.base_path, path
+                         block["block_config"]["config_paths"][key] = self._getFilePath(
+                             [yaml_dir, self.base_path], path
                          )
-                         if os.path.isfile(config_path_relative_to_base):
-                             block["block_config"]["config_paths"][key] = (
-                                 config_path_relative_to_base
-                             )

-                 if isinstance(block["block_config"]["config_paths"], list):
+                 elif isinstance(block["block_config"]["config_paths"], list):
                      for i, path in enumerate(block["block_config"]["config_paths"]):
-                         config_path_relative_to_base = os.path.join(
-                             self.base_path, path
+                         block["block_config"]["config_paths"][i] = self._getFilePath(
+                             [yaml_dir, self.base_path], path
                          )
-                         if os.path.isfile(config_path_relative_to_base):
-                             block["block_config"]["config_paths"][i] = (
-                                 config_path_relative_to_base
-                             )

              if "operation" in block["block_config"]:
                  block["block_config"]["operation"] = OPERATOR_MAP[
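Taken together, the two flow.py hunks change how a block's prompt configs are located: paths are resolved against the directory of the flow YAML first, then against the installed package's base_path, and only then treated as relative to the current working directory, so a flow definition kept outside the package can ship its own config files. The snippet below is a small, self-contained sketch of that lookup order and of how the block_config entries end up rewritten; the directory names, file names, and the resolve_config_path helper are illustrative assumptions, not part of the sdg_hub API.

import os

def resolve_config_path(search_dirs, filename):
    # Mirrors the lookup order of the new Flow._getFilePath helper (illustration only).
    if os.path.isabs(filename):
        return filename
    for d in search_dirs:
        candidate = os.path.join(d, filename)
        if os.path.isfile(candidate):
            return candidate
    # Nothing matched: keep the path as given, i.e. relative to the current directory.
    return filename

# Hypothetical locations: a user-maintained flow YAML and the installed package.
yaml_dir = "/home/user/my_flows"
base_path = "/usr/lib/python3.11/site-packages/sdg_hub"

# A block entry shaped like the ones get_flow_from_file iterates over.
block = {
    "block_config": {
        "config_path": "configs/gen_prompt.yaml",
        "config_paths": {"eval": "configs/eval_prompt.yaml"},
    }
}

cfg = block["block_config"]
cfg["config_path"] = resolve_config_path([yaml_dir, base_path], cfg["config_path"])
if isinstance(cfg.get("config_paths"), dict):
    cfg["config_paths"] = {
        k: resolve_config_path([yaml_dir, base_path], p)
        for k, p in cfg["config_paths"].items()
    }
elif isinstance(cfg.get("config_paths"), list):
    cfg["config_paths"] = [
        resolve_config_path([yaml_dir, base_path], p) for p in cfg["config_paths"]
    ]
print(cfg)

On a machine where /home/user/my_flows/configs/gen_prompt.yaml exists, the first candidate wins; otherwise the paths fall through unchanged, matching the fallback behaviour of the diffed code.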
{sdg_hub-0.1.0a2.dist-info → sdg_hub-0.1.0a3.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sdg_hub
- Version: 0.1.0a2
+ Version: 0.1.0a3
  Summary: Synthetic Data Generation
  Author-email: Red Hat AI Innovation <abhandwa@redhat.com>
  License: Apache-2.0
{sdg_hub-0.1.0a2.dist-info → sdg_hub-0.1.0a3.dist-info}/RECORD RENAMED
@@ -1,6 +1,6 @@
  sdg_hub/__init__.py,sha256=5Wa6onDndPvG4iwnjq2jK747t3-7XKdQn2WfHfq1sFc,67
- sdg_hub/_version.py,sha256=ypYjg1lOKm-KjhEwqwSB8XPrSuAfnfK8hSDDWtD6VoA,513
- sdg_hub/flow.py,sha256=t03SrX2AHHhzt5ebOAg2TwuZasNerrucanjb5_6FnOk,5025
+ sdg_hub/_version.py,sha256=wrhrM1UZdxROWn7XOHbbPZa5jOBzV8tlSBMw233huBg,513
+ sdg_hub/flow.py,sha256=3b97fMei1rWuQWeNfv-xyHKUbcMaf-d_b9Xms9J3BCQ,5425
  sdg_hub/logger_config.py,sha256=7uHEJVRfym1c4n95DOKHelLXqAus8uHsZYmzLsEjqpo,422
  sdg_hub/pipeline.py,sha256=u24ccryfy_nOSvsrWiynNmq1rOmOOkw1L5-TqJvuRSo,2339
  sdg_hub/prompts.py,sha256=dOiC9CsNbMt5Km9PnwyuW0v9zUs3cVXE5jZYwtXZTwc,1957
@@ -80,15 +80,11 @@ sdg_hub/flows/generation/skills/synth_grounded_skills.yaml,sha256=91Dm--agpmbm02
  sdg_hub/flows/generation/skills/synth_skills.yaml,sha256=PhUP2iBo4RkeFafSW-qxh4WmX_ZTfGi0UAmwN_XSTqs,1504
  sdg_hub/utils/__init__.py,sha256=UEo-9qPt5iVKBIRvgZhOI0SoIBO6zeBxOuLvUQXaM3g,185
  sdg_hub/utils/chunking.py,sha256=VSPQ8dSFI5LF4sefcI0tzWG0Vc1rM_FSMTO6xg_iFzA,2556
- sdg_hub/utils/datamixing.py,sha256=nkjyRY3AnOYYJAU4cyVk17XQwr9hUC6DicU_tuK8O8I,4099
  sdg_hub/utils/datautils.py,sha256=0t_SZ_UXBKl8uL6rVp3SUh8YKRbzKlh2oO5gr2cKyEw,389
  sdg_hub/utils/docprocessor.py,sha256=Z4J2DfLhRxMCeIeMKttwi-FdivmPqI-hjEwq6-Ub35c,12485
- sdg_hub/utils/json.py,sha256=Ub6OzSYu8PqUHki9RXxuwLkW9RSqY_DBWA5qiyLE_PA,1378
- sdg_hub/utils/models.py,sha256=mp6J1cPcEUYvr6AvRAflsnY0E5moLcmHj3MgnR0FPE4,894
  sdg_hub/utils/parse_and_convert.py,sha256=I27FdS-H2mSoZ07SsKZmNYM2F_Cg7GHTBXD7YNgASNw,13443
- sdg_hub/utils/taxonomy.py,sha256=WmO4oR4O1o1I9Yr7urfMNP80nS-p2z5bhgvDt0cg4mE,16947
- sdg_hub-0.1.0a2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sdg_hub-0.1.0a2.dist-info/METADATA,sha256=ol0u_UumQKw1VVdatPV0OW0-J7bPx_EKdqn1O7scJoo,5847
- sdg_hub-0.1.0a2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- sdg_hub-0.1.0a2.dist-info/top_level.txt,sha256=TqI7d-HE1n6zkXFkU0nF3A1Ct0P0pBaqI675uFokhx4,8
- sdg_hub-0.1.0a2.dist-info/RECORD,,
+ sdg_hub-0.1.0a3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ sdg_hub-0.1.0a3.dist-info/METADATA,sha256=vUusH0jLACOcoxvTL-e5dAPfhoTV--zgs_MJ-6IYQfQ,5847
+ sdg_hub-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ sdg_hub-0.1.0a3.dist-info/top_level.txt,sha256=TqI7d-HE1n6zkXFkU0nF3A1Ct0P0pBaqI675uFokhx4,8
+ sdg_hub-0.1.0a3.dist-info/RECORD,,
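The RECORD diff also shows that four helper modules present in 0.1.0a2 are gone from the 0.1.0a3 wheel: sdg_hub/utils/datamixing.py, sdg_hub/utils/json.py, sdg_hub/utils/models.py, and sdg_hub/utils/taxonomy.py (their full contents follow below). Downstream code that imports them should expect an ImportError against 0.1.0a3; a minimal, hedged guard might look like this, where the fallback behaviour is an assumption rather than something sdg_hub provides:

# Imports that resolve against sdg-hub 0.1.0a2 but fail against 0.1.0a3,
# because the modules were removed from the wheel (see the RECORD diff above).
try:
    from sdg_hub.utils.models import get_model_family
except ImportError:
    get_model_family = None  # assumed fallback: callers supply their own model-family logic

try:
    from sdg_hub.utils.taxonomy import read_taxonomy_leaf_nodes
except ImportError:
    read_taxonomy_leaf_nodes = None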
sdg_hub/utils/datamixing.py DELETED
@@ -1,123 +0,0 @@
- # Standard
- import json
- import os
-
- # Third Party
- from datasets import Dataset, load_dataset
- import yaml
-
- # First Party
- from sdg_hub.logger_config import setup_logger
- from .datautils import safe_concatenate_datasets
-
-
- LOGGER = setup_logger(__name__)
- ALLOWED_COLS = ["id", "messages", "metadata"]
-
-
- def adjust_train_sample_size(ds: Dataset, num_samples: int):
-     LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...")
-     df = ds.to_pandas()
-     df = df.sample(n=num_samples, random_state=42, replace=True).reset_index(drop=True)
-     return Dataset.from_pandas(df)
-
-
- def load_ds(path, sampling_size):
-     if path.endswith(".jsonl"):
-         LOGGER.info(f"Loading dataset from {path} ...")
-         dataset = load_dataset("json", data_files=path, split="train")
-     else:
-         LOGGER.info(f"Loading dataset from HF {path} ...")
-         dataset = load_dataset(path, split="train")
-     LOGGER.info(f"Dataset columns: {dataset.column_names}")
-     LOGGER.info(f"Dataset loaded with {len(dataset)} samples")
-
-     if sampling_size != 1.0:
-         if isinstance(sampling_size, int):
-             num_samples = sampling_size
-         else:
-             num_samples = int(len(dataset) * sampling_size)
-         dataset = adjust_train_sample_size(dataset, num_samples)
-
-     # move any column that is not in ALLOWED_COLS to metadata
-     def move_unallowed_cols_to_metadata(example):
-         metadata = example.get("metadata", {})
-         if isinstance(metadata, str):
-             metadata = json.loads(metadata)
-         for col in dataset.column_names:
-             if col not in ALLOWED_COLS:
-                 metadata[col] = example[col]
-                 example.pop(col)
-         example["metadata"] = json.dumps(metadata)
-         return example
-
-     dataset = dataset.map(move_unallowed_cols_to_metadata, num_proc=8)
-
-     # check if metadata column is string if not convert it using json.dumps
-     if not isinstance(dataset["metadata"][0], str):
-         dataset = dataset.map(
-             lambda x: {"metadata": json.dumps(x["metadata"])}, num_proc=8
-         )
-
-     return dataset
-
-
- def add_system_message(sample: dict, sys_prompt: str) -> dict:
-     # check if the messages have role system
-     has_system = False
-     for msg in sample["messages"]:
-         if msg["role"] == "system":
-             has_system = True
-             msg["content"] = sys_prompt
-
-     if not has_system:
-         sample["messages"].insert(0, {"role": "system", "content": sys_prompt})
-
-     return sample
-
-
- class Recipe:
-     def __init__(self, recipe_path):
-         self.recipe_path = recipe_path
-         self.recipe = self._load_recipe()
-         self.sys_prompt = self.recipe.get("sys_prompt", "")
-         self.dataset_added = False
-
-     def _load_recipe(self):
-         with open(self.recipe_path, encoding="utf-8") as fp:
-             return yaml.safe_load(fp)
-
-     def add_dataset(self, path, sampling_size=1.0):
-         self.dataset_added = True
-         self.recipe["datasets"].append({"path": path, "sampling_size": sampling_size})
-
-     def save_recipe(self, output_path):
-         # check if directory exists
-         output_dir = os.path.dirname(output_path)
-         if not os.path.exists(output_dir):
-             os.makedirs(output_dir)
-
-         with open(output_path, "w", encoding="utf-8") as fp:
-             yaml.dump(self.recipe, fp)
-
-     def save_mixed_dataset(self, output_path):
-         if not self.dataset_added:
-             LOGGER.error("No dataset added to the recipe")
-
-         mixed_ds = [
-             load_ds(dataset["path"], dataset["sampling_size"])
-             for dataset in self.recipe["datasets"]
-         ]
-
-         mixed_ds = safe_concatenate_datasets(mixed_ds)
-         mixed_ds = mixed_ds.map(
-             add_system_message, fn_kwargs={"sys_prompt": self.sys_prompt}, num_proc=8
-         )
-
-         # assert that the dataset only has the allowed columns
-         assert set(mixed_ds.column_names) == set(
-             ALLOWED_COLS
-         ), "Dataset has invalid columns"
-
-         mixed_ds.to_json(output_path, orient="records", lines=True)
-         LOGGER.info(f"Mixed Dataset saved to {output_path}")
sdg_hub/utils/json.py DELETED
@@ -1,48 +0,0 @@
- # SPDX-License-Identifier: Apache-2.0
-
- # Standard
- import io
- import json
- import os
-
-
- def _make_w_io_base(f, mode: str):
-     # pylint: disable=consider-using-with
-     if not isinstance(f, io.IOBase):
-         f_dirname = os.path.dirname(f)
-         if f_dirname != "":
-             os.makedirs(f_dirname, exist_ok=True)
-         f = open(f, mode=mode, encoding="utf-8")
-     return f
-
-
- def _make_r_io_base(f, mode: str):
-     # pylint: disable=consider-using-with
-     if not isinstance(f, io.IOBase):
-         f = open(f, mode=mode, encoding="utf-8")
-     return f
-
-
- def jdump(obj, f, mode="w", indent=4, default=str):
-     """Dump a str or dictionary to a file in json format.
-
-     Args:
-         obj: An object to be written.
-         f: A string path to the location on disk.
-         mode: Mode for opening the file.
-         indent: Indent for storing json dictionaries.
-         default: A function to handle non-serializable entries; defaults to `str`.
-     """
-     with _make_w_io_base(f, mode) as f_:
-         if isinstance(obj, (dict, list)):
-             json.dump(obj, f_, indent=indent, default=default)
-         elif isinstance(obj, str):
-             f_.write(obj)
-         else:
-             raise ValueError(f"Unexpected type: {type(obj)}")
-
-
- def jload(f, mode="r"):
-     """Load a .json file into a dictionary."""
-     with _make_r_io_base(f, mode) as f_:
-         return json.load(f_)
sdg_hub/utils/models.py DELETED
@@ -1,31 +0,0 @@
- # SPDX-License-Identifier: Apache-2.0
-
- # Standard
- import os
- import re
-
- # First Party
- from sdg_hub.utils import GenerateException
-
- # When otherwise unknown, ilab uses this as the default family
- DEFAULT_MODEL_FAMILY = "merlinite"
-
- # Model families understood by ilab
- MODEL_FAMILIES = set(("merlinite", "mixtral"))
-
- # Map model names to their family
- MODEL_FAMILY_MAPPINGS = {
-     "granite": "merlinite",
- }
-
-
- def get_model_family(forced, model_path):
-     forced = MODEL_FAMILY_MAPPINGS.get(forced, forced)
-     if forced and forced.lower() not in MODEL_FAMILIES:
-         raise GenerateException("Unknown model family: %s" % forced)
-
-     # Try to guess the model family based on the model's filename
-     guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
-     guess = MODEL_FAMILY_MAPPINGS.get(guess, guess)
-
-     return guess if guess in MODEL_FAMILIES else DEFAULT_MODEL_FAMILY
sdg_hub/utils/taxonomy.py DELETED
@@ -1,489 +0,0 @@
- # SPDX-License-Identifier: Apache-2.0
-
- # Standard
- from functools import cache
- from pathlib import Path
- from typing import Any, Dict, List, Mapping, Optional, Union
- import glob
- import json
- import logging
- import os
- import re
- import subprocess
- import tempfile
-
- # Third Party
- import git
- import gitdb
- import yaml
-
- # First Party
- from sdg_hub import utils
- from sdg_hub.utils import chunking
-
- logger = logging.getLogger(__name__)
-
- DEFAULT_YAML_RULES = """\
- extends: relaxed
-
- rules:
-   line-length:
-     max: 120
- """
-
-
- class TaxonomyReadingException(Exception):
-     """An exception raised during reading of the taxonomy."""
-
-
- TAXONOMY_FOLDERS: List[str] = ["compositional_skills", "knowledge"]
- """Taxonomy folders which are also the schema names"""
-
-
- def _istaxonomyfile(fn):
-     path = Path(fn)
-     if path.suffix == ".yaml" and path.parts[0] in TAXONOMY_FOLDERS:
-         return True
-     return False
-
-
- def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
-     repo = git.Repo(repo)
-     untracked_files = [u for u in repo.untracked_files if _istaxonomyfile(u)]
-
-     branches = [b.name for b in repo.branches]
-
-     head_commit = None
-     if "/" in base:
-         re_git_branch = re.compile(f"remotes/{base}$", re.MULTILINE)
-     elif base in branches:
-         re_git_branch = re.compile(f"{base}$", re.MULTILINE)
-     else:
-         try:
-             head_commit = repo.commit(base)
-         except gitdb.exc.BadName as e:
-             raise SystemExit(
-                 yaml.YAMLError(
-                     f'Couldn\'t find the taxonomy git ref "{base}" from the current HEAD'
-                 )
-             ) from e
-
-     # Move backwards from HEAD until we find the first commit that is part of base
-     # then we can take our diff from there
-     current_commit = repo.commit("HEAD")
-     while not head_commit:
-         branches = repo.git.branch("-a", "--contains", current_commit.hexsha)
-         if re_git_branch.findall(branches):
-             head_commit = current_commit
-             break
-         try:
-             current_commit = current_commit.parents[0]
-         except IndexError as e:
-             raise SystemExit(
-                 yaml.YAMLError(
-                     f'Couldn\'t find the taxonomy base branch "{base}" from the current HEAD'
-                 )
-             ) from e
-
-     modified_files = [
-         d.b_path
-         for d in head_commit.diff(None)
-         if not d.deleted_file and _istaxonomyfile(d.b_path)
-     ]
-
-     updated_taxonomy_files = list(set(untracked_files + modified_files))
-     return updated_taxonomy_files
-
-
- def _get_documents(
-     source: Dict[str, Union[str, List[str]]],
-     skip_checkout: bool = False,
- ) -> List[str]:
-     """
-     Retrieve the content of files from a Git repository.
-
-     Args:
-         source (dict): Source info containing repository URL, commit hash, and list of file patterns.
-
-     Returns:
-         List[str]: List of document contents.
-     """ ""
-     repo_url = source.get("repo")
-     commit_hash = source.get("commit")
-     file_patterns = source.get("patterns", [])
-     with tempfile.TemporaryDirectory() as temp_dir:
-         try:
-             repo = git.Repo.clone_from(repo_url, temp_dir)
-             if not skip_checkout:
-                 repo.git.checkout(commit_hash)
-
-             file_contents = []
-
-             logger.debug("Processing files...")
-             for pattern in file_patterns:
-                 for file_path in glob.glob(os.path.join(repo.working_dir, pattern)):
-                     if os.path.isfile(file_path) and file_path.endswith(".md"):
-                         with open(file_path, "r", encoding="utf-8") as file:
-                             file_contents.append(file.read())
-
-             if file_contents:
-                 return file_contents
-             raise SystemExit("Couldn't find knowledge documents")
-         except (OSError, git.exc.GitCommandError, FileNotFoundError) as e:
-             raise e
-
-
- @cache
- def _load_schema(path: "importlib.resources.abc.Traversable") -> "referencing.Resource":
-     """Load the schema from the path into a Resource object.
-
-     Args:
-         path (Traversable): Path to the schema to be loaded.
-
-     Raises:
-         NoSuchResource: If the resource cannot be loaded.
-
-     Returns:
-         Resource: A Resource containing the requested schema.
-     """
-     # pylint: disable=C0415
-     # Third Party
-     from referencing import Resource
-     from referencing.exceptions import NoSuchResource
-     from referencing.jsonschema import DRAFT202012
-
-     try:
-         contents = json.loads(path.read_text(encoding="utf-8"))
-         resource = Resource.from_contents(
-             contents=contents, default_specification=DRAFT202012
-         )
-     except Exception as e:
-         raise NoSuchResource(ref=str(path)) from e
-     return resource
-
-
- def _validate_yaml(contents: Mapping[str, Any], taxonomy_path: Path) -> int:
-     """Validate the parsed yaml document using the taxonomy path to
-     determine the proper schema.
-
-     Args:
-         contents (Mapping): The parsed yaml document to validate against the schema.
-         taxonomy_path (Path): Relative path of the taxonomy yaml document where the
-             first element is the schema to use.
-
-     Returns:
-         int: The number of errors found during validation.
-             Messages for each error have been logged.
-     """
-     # pylint: disable=C0415
-     # Standard
-     from importlib import resources
-
-     # Third Party
-     from jsonschema.protocols import Validator
-     from jsonschema.validators import validator_for
-     from referencing import Registry, Resource
-     from referencing.exceptions import NoSuchResource
-     from referencing.typing import URI
-
-     errors = 0
-     version = _get_version(contents)
-     schemas_path = resources.files("instructlab.schema").joinpath(f"v{version}")
-
-     def retrieve(uri: URI) -> Resource:
-         path = schemas_path.joinpath(uri)
-         return _load_schema(path)
-
-     schema_name = taxonomy_path.parts[0]
-     if schema_name not in TAXONOMY_FOLDERS:
-         schema_name = "knowledge" if "document" in contents else "compositional_skills"
-         logger.info(
-             f"Cannot determine schema name from path {taxonomy_path}. Using {schema_name} schema."
-         )
-
-     try:
-         schema_resource = retrieve(f"{schema_name}.json")
-         schema = schema_resource.contents
-         validator_cls = validator_for(schema)
-         validator: Validator = validator_cls(
-             schema, registry=Registry(retrieve=retrieve)
-         )
-
-         for validation_error in validator.iter_errors(contents):
-             errors += 1
-             yaml_path = validation_error.json_path[1:]
-             if not yaml_path:
-                 yaml_path = "."
-             if validation_error.validator == "minItems":
-                 # Special handling for minItems which can have a long message for seed_examples
-                 message = (
-                     f"Value must have at least {validation_error.validator_value} items"
-                 )
-             else:
-                 message = validation_error.message[-200:]
-             logger.error(
-                 f"Validation error in {taxonomy_path}: [{yaml_path}] {message}"
-             )
-     except NoSuchResource as e:
-         cause = e.__cause__ if e.__cause__ is not None else e
-         errors += 1
-         logger.error(f"Cannot load schema file {e.ref}. {cause}")
-
-     return errors
-
-
- def _get_version(contents: Mapping) -> int:
-     version = contents.get("version", 1)
-     if not isinstance(version, int):
-         # schema validation will complain about the type
-         try:
-             version = int(version)
-         except ValueError:
-             version = 1  # fallback to version 1
-     return version
-
-
- # pylint: disable=broad-exception-caught
- def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None):
-     seed_instruction_data = []
-     warnings = 0
-     errors = 0
-     file_path = Path(file_path).resolve()
-     # file should end with ".yaml" explicitly
-     if file_path.suffix != ".yaml":
-         logger.warning(
-             f"Skipping {file_path}! Use lowercase '.yaml' extension instead."
-         )
-         warnings += 1
-         return None, warnings, errors
-     for i in range(len(file_path.parts) - 1, -1, -1):
-         if file_path.parts[i] in TAXONOMY_FOLDERS:
-             taxonomy_path = Path(*file_path.parts[i:])
-             break
-     else:
-         taxonomy_path = file_path
-     # read file if extension is correct
-     try:
-         with open(file_path, "r", encoding="utf-8") as file:
-             contents = yaml.safe_load(file)
-         if not contents:
-             logger.warning(f"Skipping {file_path} because it is empty!")
-             warnings += 1
-             return None, warnings, errors
-         if not isinstance(contents, Mapping):
-             logger.error(
-                 f"{file_path} is not valid. The top-level element is not an object with key-value pairs."
-             )
-             errors += 1
-             return None, warnings, errors
-
-         # do general YAML linting if specified
-         version = _get_version(contents)
-         if version > 1:  # no linting for version 1 yaml
-             if yaml_rules is not None:
-                 is_file = os.path.isfile(yaml_rules)
-                 if is_file:
-                     logger.debug(f"Using YAML rules from {yaml_rules}")
-                     yamllint_cmd = [
-                         "yamllint",
-                         "-f",
-                         "parsable",
-                         "-c",
-                         yaml_rules,
-                         file_path,
-                         "-s",
-                     ]
-                 else:
-                     logger.debug(f"Cannot find {yaml_rules}. Using default rules.")
-                     yamllint_cmd = [
-                         "yamllint",
-                         "-f",
-                         "parsable",
-                         "-d",
-                         DEFAULT_YAML_RULES,
-                         file_path,
-                         "-s",
-                     ]
-             else:
-                 yamllint_cmd = [
-                     "yamllint",
-                     "-f",
-                     "parsable",
-                     "-d",
-                     DEFAULT_YAML_RULES,
-                     file_path,
-                     "-s",
-                 ]
-             try:
-                 subprocess.check_output(yamllint_cmd, text=True)
-             except subprocess.SubprocessError as e:
-                 lint_messages = [f"Problems found in file {file_path}"]
-                 parsed_output = e.output.splitlines()
-                 for p in parsed_output:
-                     errors += 1
-                     delim = str(file_path) + ":"
-                     parsed_p = p.split(delim)[1]
-                     lint_messages.append(parsed_p)
-                 logger.error("\n".join(lint_messages))
-                 return None, warnings, errors
-
-         # validation_errors = _validate_yaml(contents, taxonomy_path)
-         # if validation_errors:
-         #     errors += validation_errors
-         #     return None, warnings, errors
-
-         # get seed instruction data
-         tax_path = "->".join(taxonomy_path.parent.parts)
-         task_description = contents.get("task_description", None)
-         domain = contents.get("domain")
-         documents = contents.get("document")
-         if documents:
-             documents = _get_documents(source=documents)
-             logger.debug("Content from git repo fetched")
-
-         for seed_example in contents.get("seed_examples"):
-             context = seed_example.get("context", "")
-             if 'questions_and_answers' in seed_example:
-                 question_answer_list = seed_example.get("questions_and_answers")
-                 seed_instruction_data.append(
-                     {
-                         "questions_and_answers": question_answer_list,
-                         "input": context,
-                         "taxonomy_path": tax_path,
-                         "document": documents,
-                         "domain": domain,
-                         "document_outline": contents.get("document_outline")
-                     }
-                 )
-             else:
-                 question = seed_example.get("question")
-                 answer = seed_example.get("answer")
-
-                 seed_instruction_data.append(
-                     {
-                         "instruction": question,
-                         "input": context,
-                         "output": answer,
-                         "taxonomy_path": tax_path,
-                         "task_description": task_description,
-                         "document": documents,
-                         "domain": domain,
-                     }
-                 )
-     except Exception as e:
-         errors += 1
-         raise TaxonomyReadingException(f"Exception {e} raised in {file_path}") from e
-
-     return seed_instruction_data, warnings, errors
-
-
- def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
-     seed_instruction_data = []
-     is_file = os.path.isfile(taxonomy)
-     if is_file:  # taxonomy is file
-         seed_instruction_data, warnings, errors = _read_taxonomy_file(
-             taxonomy, yaml_rules
-         )
-         if warnings:
-             logger.warning(
-                 f"{warnings} warnings (see above) due to taxonomy file not (fully) usable."
-             )
-         if errors:
-             raise SystemExit(yaml.YAMLError("Taxonomy file with errors! Exiting."))
-     else:  # taxonomy is dir
-         # Gather the new or changed YAMLs using git diff
-         updated_taxonomy_files = _get_taxonomy_diff(taxonomy, taxonomy_base)
-         total_errors = 0
-         total_warnings = 0
-         if updated_taxonomy_files:
-             logger.debug("Found new taxonomy files:")
-             for e in updated_taxonomy_files:
-                 logger.debug(f"* {e}")
-         for f in updated_taxonomy_files:
-             file_path = os.path.join(taxonomy, f)
-             data, warnings, errors = _read_taxonomy_file(file_path, yaml_rules)
-             total_warnings += warnings
-             total_errors += errors
-             if data:
-                 seed_instruction_data.extend(data)
-         if total_warnings:
-             logger.warning(
-                 f"{total_warnings} warnings (see above) due to taxonomy files that were not (fully) usable."
-             )
-         if total_errors:
-             raise SystemExit(
-                 yaml.YAMLError(f"{total_errors} taxonomy files with errors! Exiting.")
-             )
-     return seed_instruction_data
-
-
- def read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules):
-     seed_instruction_data = read_taxonomy(taxonomy, taxonomy_base, yaml_rules)
-
-     # Transform into a more convenient format to feed into our updated SDG library
-     leaf_nodes = {}
-     for seed in seed_instruction_data:
-         node = leaf_nodes.setdefault(seed["taxonomy_path"], [])
-         node.append(seed)
-         leaf_nodes[seed["taxonomy_path"]] = node
-
-     return leaf_nodes
-
-
- def _knowledge_leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count):
-     samples = []
-     # document is the same for the whole leaf node
-     chunks = (
-         chunking.chunk_document(
-             documents=leaf_node[0]["document"],
-             server_ctx_size=server_ctx_size,
-             chunk_word_count=chunk_word_count,
-         )
-         if leaf_node[0].get("document")
-         else []
-     )
-
-     # domain is the same for the whole leaf node
-     domain = leaf_node[0].get("domain")
-
-     for chunk in chunks:
-         # pylint: disable=consider-using-enumerate
-         for icl_ in leaf_node:
-             icl_query = {f"icl_query_{idx+1}": val["question"] for idx, val in enumerate(icl_["questions_and_answers"])}
-             icl_resp = {f"icl_response_{idx+1}": val["answer"] for idx, val in enumerate(icl_["questions_and_answers"])}
-             samples_row = {
-                 "icl_document": icl_["input"],
-                 "document": chunk,
-                 "document_outline": icl_["document_outline"],
-                 "domain": domain
-             }
-             samples_row.update(icl_query)
-             samples_row.update(icl_resp)
-             samples.append(samples_row)
-
-     return samples
-
-
- def _skill_leaf_node_to_samples(leaf_node):
-     samples = []
-
-     # pylint: disable=consider-using-enumerate
-     for i in range(len(leaf_node)):
-         samples.append({})
-         samples[-1]["task_description"] = leaf_node[i]["task_description"]
-         if leaf_node[i].get("input"):
-             samples[-1]["seed_context"] = leaf_node[i]["input"]
-         samples[-1]["seed_question"] = leaf_node[i]["instruction"]
-         samples[-1]["seed_response"] = leaf_node[i]["output"]
-
-     return samples
-
-
- def leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count):
-     if not leaf_node:
-         return []
-     if leaf_node[0].get("document"):
-         return _knowledge_leaf_node_to_samples(
-             leaf_node, server_ctx_size, chunk_word_count
-         )
-     return _skill_leaf_node_to_samples(leaf_node)