tokendye 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: tokendye
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.12
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: pydantic
8
+ Requires-Dist: jsonl5>=0.1.1
9
+
10
+ A python package
@@ -0,0 +1 @@
1
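+ A Python package that builds token-level "dye" masks (per-token segment labels) and target masks for chat-formatted training data.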
@@ -0,0 +1,4 @@
1
+ from .dye_dataset import DyeDataset
2
+ from .functional import from_jsonl, from_jsonl5
3
+
4
# Public API of the package: the dataset class plus its two file loaders.
__all__ = [
    "DyeDataset",
    "from_jsonl",
    "from_jsonl5",
]
@@ -0,0 +1,122 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from torch.utils.data import Dataset
4
+
5
+ if TYPE_CHECKING:
6
+ from ..dye_label import DyeLabel
7
+
8
# Chat role used to render each segment, keyed by the segment's dye name.
# Dyes not listed here (custom dyes) are rendered as "user" messages.
_DYE_ROLES: dict[str, str] = {
    "system": "system",
    "user": "user",
    "file_text": "user",
    "tool_callback": "tool",
}


def _chat_template_ids(tokenizer, messages: list[dict], *, add_generation_prompt: bool) -> list[int]:
    """Render ``messages`` with the tokenizer's chat template and return the token ids.

    ``apply_chat_template(..., tokenize=True)`` may return either a plain list
    of ids or a dict-like encoding keyed by ``"input_ids"`` (presumably
    depending on the tokenizer library version / its ``return_dict`` default);
    both shapes are accepted here.
    """
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
    )
    # TODO: the attention_mask of a dict-like result might be useful later.
    if hasattr(encoded, "keys"):
        encoded = encoded["input_ids"]
    return list(encoded)


def _build_sequence(
    data: dict,
    tokenizer,
    labels: list['DyeLabel']
):
    """Tokenize one example and build its per-token dye mask and target mask.

    Each entry of ``data["segments"]`` (a dict with ``"dye"`` and ``"text"``
    keys) is rendered as one chat message, the conversation is tokenized with
    the tokenizer's chat template, and every segment's tokens are located in
    that result so they can be tagged with the segment's label id. The chat
    template's generation prompt, the tokens of ``data["target"]`` and an EOS
    token are then appended.

    Args:
        data: One raw example with a ``"segments"`` list and a ``"target"`` string.
        tokenizer: Object providing ``apply_chat_template``, ``encode`` and
            ``eos_token_id`` (presumably a Hugging Face tokenizer).
        labels: Objects with ``name`` and ``id`` attributes; a segment's
            ``"dye"`` is matched against ``name``.

    Returns:
        ``(input_ids, dye_mask, target_mask)``: three lists of equal length.
        ``dye_mask`` holds the label id for tokens of a segment and ``-1``
        everywhere else; ``target_mask`` is ``True`` only for the target tokens
        and the trailing EOS (the tokens that count towards the loss).

    Raises:
        KeyError: A segment's dye has no label with that name.
        ValueError: A non-empty segment's tokens are not found exactly once in
            the rendered conversation (segments are encoded on their own, so
            tokenization across message boundaries can also cause a miss), the
            template output with a generation prompt does not extend the output
            without one, or the tokenizer has no EOS token id.
    """
    segments = data["segments"]

    # 1. Turn every segment into a chat message.
    messages = [
        {"role": _DYE_ROLES.get(segment["dye"], "user"), "content": segment["text"]}
        for segment in segments
    ]

    # Render the conversation without the generation prompt.
    context_ids = _chat_template_ids(tokenizer, messages, add_generation_prompt=False)
    context_len = len(context_ids)

    # 2. Encode each segment on its own, locate it inside context_ids and
    #    paint those positions with the segment's label id.
    label_map = {label.name: label.id for label in labels}
    context_dye = [-1] * context_len
    for segment in segments:
        dye = segment["dye"]
        text = segment["text"]
        try:
            dye_id = label_map[dye]
        except KeyError:
            raise KeyError(f"no label named {dye!r}; known labels: {list(label_map)}") from None

        content_ids = tokenizer.encode(text, add_special_tokens=False)
        if not content_ids:
            # An empty segment has no tokens to dye.
            continue

        positions = _find_all_sublist(context_ids, content_ids)
        if not positions:
            raise ValueError(f"content not found in full_ids: {text!r}")
        if len(positions) > 1:
            raise ValueError(
                f"ambiguous match ({len(positions)} hits), "
                f"segment content occurs more than once in full_ids "
                f"(e.g. two segments have identical content): {text!r}"
            )

        start = positions[0]
        context_dye[start:start + len(content_ids)] = [dye_id] * len(content_ids)

    # 3. The generation prompt is whatever the template appends when asked for
    #    one; this only works if the shorter rendering is a prefix of the longer.
    with_gen_ids = _chat_template_ids(tokenizer, messages, add_generation_prompt=True)
    if with_gen_ids[:context_len] != context_ids:
        raise ValueError(
            "chat template output with add_generation_prompt=True does not extend "
            "the output without it; cannot isolate the generation prompt"
        )
    gen_prompt_ids = with_gen_ids[context_len:]

    # 4. Target followed by EOS; these are the only tokens marked for the loss.
    target_ids = tokenizer.encode(data["target"], add_special_tokens=False)
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        raise ValueError("tokenizer.eos_token_id is None; cannot append an EOS token")
    answer_ids = target_ids + [eos_id]

    input_ids = context_ids + gen_prompt_ids + answer_ids
    dye_mask = context_dye + [-1] * (len(gen_prompt_ids) + len(answer_ids))
    target_mask = [False] * (context_len + len(gen_prompt_ids)) + [True] * len(answer_ids)

    return input_ids, dye_mask, target_mask
86
+
87
+
88
class DyeDataset(Dataset):
    """Map-style dataset holding pre-tokenized, dye-annotated examples.

    All raw examples are converted eagerly in ``__init__`` by
    ``_build_sequence``; each stored item is a dict with the keys
    ``"input_ids"``, ``"dye_mask"`` and ``"target_mask"``.
    """

    def __init__(
        self,
        raw_data: list[dict],
        tokenizer,
        labels: list['DyeLabel']
    ):
        built = (_build_sequence(example, tokenizer, labels) for example in raw_data)
        self.dataset: list[dict] = [
            {"input_ids": ids, "dye_mask": dyes, "target_mask": targets}
            for ids, dyes, targets in built
        ]

    def __len__(self):
        """Number of stored examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the stored example dict at position ``idx``."""
        return self.dataset[idx]
115
+
116
+
117
def _find_all_sublist(full, sub) -> list[int]:
    """Return the start index of every occurrence of ``sub`` inside ``full``.

    Occurrences may overlap; an empty ``sub`` yields an empty list.
    """
    total, width = len(full), len(sub)
    if not width:
        return []
    matches: list[int] = []
    for start in range(total - width + 1):
        if full[start:start + width] == sub:
            matches.append(start)
    return matches
@@ -0,0 +1,28 @@
1
+ import json
2
+ from typing import TYPE_CHECKING
3
+
4
+ import jsonl5
5
+
6
+ from . import DyeDataset
7
+
8
+ if TYPE_CHECKING:
9
+ from os import PathLike
10
+
11
+ from ..dye_label import DyeLabel
12
+
13
+
14
def from_jsonl(data_path: "PathLike | str", tokenizer, labels: list["DyeLabel"]) -> "DyeDataset":
    """Load a JSON Lines file and build a ``DyeDataset`` from its records.

    Every non-blank line must contain one JSON value (one raw example as
    consumed by ``DyeDataset``); blank or whitespace-only lines are skipped.

    Args:
        data_path: Path of the ``.jsonl`` file.
        tokenizer: Passed through to ``DyeDataset``.
        labels: Passed through to ``DyeDataset``.

    Returns:
        The constructed ``DyeDataset``.

    Raises:
        OSError: The file cannot be opened.
        json.JSONDecodeError: A non-blank line is not valid JSON.
    """
    # JSON Lines is UTF-8 by definition; don't depend on the platform's default
    # encoding, which may fail on non-ASCII text (e.g. on Windows).
    with open(data_path, encoding="utf-8") as f:
        raw_data = [json.loads(line) for line in f if line.strip()]

    return DyeDataset(raw_data, tokenizer, labels)
22
+
23
+
24
def from_jsonl5(data_path: "PathLike | str", tokenizer, labels: list["DyeLabel"]) -> "DyeDataset":
    """Load a JSONL5 file with ``jsonl5.load`` and build a ``DyeDataset`` from it.

    Args:
        data_path: Path of the JSONL5 file.
        tokenizer: Passed through to ``DyeDataset``.
        labels: Passed through to ``DyeDataset``.

    Returns:
        The constructed ``DyeDataset``.

    Raises:
        OSError: The file cannot be opened. Parse errors are whatever
            ``jsonl5.load`` raises.
    """
    # Decode as UTF-8 explicitly (same as from_jsonl) instead of relying on the
    # platform's default encoding.
    with open(data_path, encoding="utf-8") as f:
        raw_data = jsonl5.load(f)

    return DyeDataset(raw_data, tokenizer, labels)
@@ -0,0 +1,10 @@
1
+ [project]
2
+ name = "tokendye"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "pydantic",
9
+ "jsonl5>=0.1.1",
10
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: tokendye
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.12
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: pydantic
8
+ Requires-Dist: jsonl5>=0.1.1
9
+
10
+ A python package
@@ -0,0 +1,10 @@
1
+ README.md
2
+ pyproject.toml
3
+ dataset/__init__.py
4
+ dataset/dye_dataset.py
5
+ dataset/functional.py
6
+ tokendye.egg-info/PKG-INFO
7
+ tokendye.egg-info/SOURCES.txt
8
+ tokendye.egg-info/dependency_links.txt
9
+ tokendye.egg-info/requires.txt
10
+ tokendye.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ pydantic
2
+ jsonl5>=0.1.1
@@ -0,0 +1 @@
1
+ dataset