tokendye 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokendye-0.1.0/PKG-INFO +10 -0
- tokendye-0.1.0/README.md +1 -0
- tokendye-0.1.0/dataset/__init__.py +4 -0
- tokendye-0.1.0/dataset/dye_dataset.py +122 -0
- tokendye-0.1.0/dataset/functional.py +28 -0
- tokendye-0.1.0/pyproject.toml +10 -0
- tokendye-0.1.0/setup.cfg +4 -0
- tokendye-0.1.0/tokendye.egg-info/PKG-INFO +10 -0
- tokendye-0.1.0/tokendye.egg-info/SOURCES.txt +10 -0
- tokendye-0.1.0/tokendye.egg-info/dependency_links.txt +1 -0
- tokendye-0.1.0/tokendye.egg-info/requires.txt +2 -0
- tokendye-0.1.0/tokendye.egg-info/top_level.txt +1 -0
tokendye-0.1.0/PKG-INFO
ADDED
tokendye-0.1.0/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
A Python package that turns dye-annotated chat segments into a PyTorch dataset with per-token dye and target masks.
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
from torch.utils.data import Dataset
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from ..dye_label import DyeLabel
|
|
7
|
+
|
|
8
|
+
def _build_sequence(
    data: dict,
    tokenizer,
    labels: list['DyeLabel']
) -> tuple[list[int], list[int], list[bool]]:
    """Tokenize one dye-annotated sample into model inputs and per-token masks.

    ``data["segments"]`` is a sequence of items with ``"dye"`` and ``"text"``
    keys; they are rendered as a chat conversation with
    ``tokenizer.apply_chat_template``. ``data["target"]`` is the text appended
    after the generation prompt.

    Returns:
        A tuple ``(input_ids, dye_mask, target_mask)`` of equal-length lists:

        * ``input_ids`` -- rendered conversation + generation prompt + target
          tokens + one EOS token.
        * ``dye_mask`` -- for tokens of a segment's text, the ``id`` of the
          label whose ``name`` equals that segment's dye; ``-1`` everywhere
          else (template markup, generation prompt, target, EOS).
        * ``target_mask`` -- ``True`` for target tokens and EOS (the positions
          that count toward the loss), ``False`` elsewhere.

    Raises:
        KeyError: a segment's dye matches no label name.
        ValueError: a segment's tokens cannot be located exactly once in the
            rendered conversation; the rendering with a generation prompt
            does not start with the rendering without it; or the tokenizer
            has no EOS token id.
    """
    from collections.abc import Mapping

    def _chat_ids(msgs: list[dict], add_generation_prompt: bool) -> list[int]:
        # Render msgs to token ids. Depending on the transformers version and
        # its `return_dict` default, apply_chat_template(tokenize=True) yields
        # either a plain list of ids or a dict-like BatchEncoding; accept both.
        rendered = tokenizer.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
        )
        if isinstance(rendered, Mapping):
            # TODO: the attention_mask might be worth keeping as well.
            rendered = rendered["input_ids"]
        return list(rendered)

    # Materialize once: the segments are walked twice below.
    segments = list(data["segments"])

    # Dyes without a dedicated chat role ("file_text" or any other dye) are
    # rendered as user turns.
    role_by_dye = {"system": "system", "user": "user", "tool_callback": "tool"}
    messages = [  # TODO: optimize
        {"role": role_by_dye.get(seg["dye"], "user"), "content": seg["text"]}
        for seg in segments
    ]

    # 1. Render the context without a generation prompt.
    full_ids = _chat_ids(messages, add_generation_prompt=False)

    # 2. Encode each segment's text on its own and locate it inside full_ids to
    #    build dye_mask. This relies on the text tokenizing the same way in
    #    isolation as inside the template; otherwise it is reported as not found.
    dye_mask = [-1] * len(full_ids)
    label_map = {label.name: label.id for label in labels}
    for segment in segments:
        dye = segment["dye"]
        text = segment["text"]
        try:
            dye_id = label_map[dye]
        except KeyError:
            raise KeyError(
                f"no label named {dye!r}; known label names: {list(label_map)}"
            ) from None

        content_ids = tokenizer.encode(text, add_special_tokens=False)
        if not content_ids:
            # Empty text contributes no tokens, so there is nothing to dye.
            continue

        positions = _find_all_sublist(full_ids, content_ids)
        if not positions:
            raise ValueError(f"content not found in full_ids: {text!r}")
        if len(positions) > 1:
            raise ValueError(
                f"ambiguous match ({len(positions)} hits): the segment's tokens "
                f"occur more than once in the rendered conversation "
                f"(e.g. two segments have identical content): {text!r}"
            )

        start = positions[0]
        dye_mask[start:start + len(content_ids)] = [dye_id] * len(content_ids)

    # 3. Target part: the generation prompt is whatever the template appends
    #    when add_generation_prompt=True.
    context_len = len(full_ids)
    full_ids_with_gen = _chat_ids(messages, add_generation_prompt=True)
    if full_ids_with_gen[:context_len] != full_ids:
        raise ValueError(
            "chat template output with add_generation_prompt=True does not "
            "start with the output without it; cannot isolate the generation prompt"
        )
    gen_prompt_ids = full_ids_with_gen[context_len:]  # only the newly added tokens

    if tokenizer.eos_token_id is None:
        raise ValueError("tokenizer has no eos_token_id; cannot terminate the target")
    target_ids = tokenizer.encode(data["target"], add_special_tokens=False)
    eos = [tokenizer.eos_token_id]

    input_ids = full_ids + gen_prompt_ids + target_ids + eos
    dye_mask = dye_mask + [-1] * (len(gen_prompt_ids) + len(target_ids) + len(eos))

    target_mask = (
        [False] * (context_len + len(gen_prompt_ids))
        + [True] * len(target_ids)
        + [True]  # EOS also counts toward the loss
    )

    return input_ids, dye_mask, target_mask
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class DyeDataset(Dataset):
    """Map-style torch ``Dataset`` over pre-tokenized, dye-annotated samples.

    Every raw sample is converted up front in ``__init__`` by
    ``_build_sequence``. Each item is a dict holding that function's three
    outputs under the keys ``"input_ids"``, ``"dye_mask"`` and
    ``"target_mask"``.
    """

    def __init__(
        self,
        raw_data: list[dict],
        tokenizer,
        labels: list['DyeLabel']
    ):
        """Tokenize every sample of ``raw_data`` using ``tokenizer`` and ``labels``."""
        records: list[dict] = []
        for sample in raw_data:
            ids, dyes, targets = _build_sequence(sample, tokenizer, labels)
            records.append(
                {"input_ids": ids, "dye_mask": dyes, "target_mask": targets}
            )
        self.dataset: list[dict] = records

    def __len__(self):
        """Return the number of tokenized samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the stored dict for ``idx`` (the stored object itself, not a copy)."""
        return self.dataset[idx]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _find_all_sublist(full, sub) -> list[int]:
    """Return the start index of every occurrence of ``sub`` inside ``full``.

    Overlapping occurrences are all reported. An empty ``sub`` yields ``[]``,
    as does a ``sub`` longer than ``full``.
    """
    width = len(sub)
    if not width:
        return []
    last_start = len(full) - width
    matches = []
    start = 0
    while start <= last_start:
        if full[start:start + width] == sub:
            matches.append(start)
        start += 1
    return matches
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
import jsonl5
|
|
5
|
+
|
|
6
|
+
from . import DyeDataset
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from os import PathLike
|
|
10
|
+
|
|
11
|
+
from ..dye_label import DyeLabel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def from_jsonl(data_path: "PathLike | str", tokenizer, labels: list["DyeLabel"]):
    """Build a ``DyeDataset`` from a JSON Lines file (one JSON value per line).

    Args:
        data_path: Path of the ``.jsonl`` file.
        tokenizer: Passed unchanged to ``DyeDataset``.
        labels: Passed unchanged to ``DyeDataset``.

    Returns:
        The constructed ``DyeDataset``.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly: open()'s default encoding is platform
    # dependent, while JSON text exchanged between systems is UTF-8.
    with open(data_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. stray empty lines) instead of failing on them.
        raw_data = [json.loads(line) for line in f if line.strip()]

    return DyeDataset(raw_data, tokenizer, labels)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def from_jsonl5(data_path: "PathLike | str", tokenizer, labels: list["DyeLabel"]):
    """Build a ``DyeDataset`` from a file parsed by ``jsonl5.load``.

    Args:
        data_path: Path of the input file.
        tokenizer: Passed unchanged to ``DyeDataset``.
        labels: Passed unchanged to ``DyeDataset``.

    Returns:
        The constructed ``DyeDataset``.
    """
    # Decode as UTF-8 explicitly rather than with the platform-dependent
    # default encoding.
    with open(data_path, encoding="utf-8") as f:
        # Materialize while the file is still open. NOTE(review): whether
        # jsonl5.load returns a list or a lazy iterator is not visible here;
        # list() is harmless in the first case and required in the second.
        raw_data = list(jsonl5.load(f))

    return DyeDataset(raw_data, tokenizer, labels)
|
tokendye-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
dataset/__init__.py
|
|
4
|
+
dataset/dye_dataset.py
|
|
5
|
+
dataset/functional.py
|
|
6
|
+
tokendye.egg-info/PKG-INFO
|
|
7
|
+
tokendye.egg-info/SOURCES.txt
|
|
8
|
+
tokendye.egg-info/dependency_links.txt
|
|
9
|
+
tokendye.egg-info/requires.txt
|
|
10
|
+
tokendye.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
dataset
|