persona-data 0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- persona_data-0.1.0/PKG-INFO +127 -0
- persona_data-0.1.0/README.md +116 -0
- persona_data-0.1.0/pyproject.toml +16 -0
- persona_data-0.1.0/src/persona_data/__init__.py +0 -0
- persona_data-0.1.0/src/persona_data/environment.py +42 -0
- persona_data-0.1.0/src/persona_data/persona_guess.py +71 -0
- persona_data-0.1.0/src/persona_data/prompts.py +127 -0
- persona_data-0.1.0/src/persona_data/synth_persona.py +211 -0
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: persona-data
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared dataset loading and prompt formatting for implicit-personalization projects
|
|
5
|
+
Requires-Dist: huggingface-hub>=0.30.0
|
|
6
|
+
Requires-Dist: python-dotenv>=1.0.0
|
|
7
|
+
Requires-Dist: torch>=2.0.0
|
|
8
|
+
Requires-Dist: numpy>=1.24.0
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# persona-data
|
|
13
|
+
|
|
14
|
+
[](https://implicit-personalization.github.io/persona-data/)
|
|
15
|
+
|
|
16
|
+
Shared dataset loading, prompt formatting, and environment utilities for the [implicit-personalization](https://github.com/implicit-personalization) projects.
|
|
17
|
+
|
|
18
|
+
## Overview
|
|
19
|
+
|
|
20
|
+
`persona-data` provides the common dataset and prompt helpers used across the persona projects:
|
|
21
|
+
|
|
22
|
+
- `SynthPersonaDataset` for persona profiles plus QA pairs
|
|
23
|
+
- `PersonaGuessDataset` for turn-based persona games
|
|
24
|
+
- prompt helpers for roleplay and multiple-choice evaluation
|
|
25
|
+
- environment helpers for seeds, devices, and artifact paths
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
Add as a uv git source in your project's `pyproject.toml`:
|
|
30
|
+
|
|
31
|
+
```toml
|
|
32
|
+
[project]
|
|
33
|
+
dependencies = ["persona-data"]
|
|
34
|
+
|
|
35
|
+
[tool.uv.sources]
|
|
36
|
+
persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
Then run `uv sync`.
|
|
40
|
+
|
|
41
|
+
For local development alongside other repos, use an editable path source:
|
|
42
|
+
|
|
43
|
+
```toml
|
|
44
|
+
[tool.uv.sources]
|
|
45
|
+
persona-data = { path = "../persona-data", editable = true }
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Package layout
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
src/persona_data/
|
|
52
|
+
├── __init__.py
|
|
53
|
+
├── synth_persona.py # SynthPersonaDataset, PersonaDataset, PersonaData, QAPair, BiographySection
|
|
54
|
+
├── persona_guess.py # PersonaGuessDataset, GameRecord, Turn
|
|
55
|
+
├── prompts.py # format_roleplay_prompt, format_mc_question, format_messages
|
|
56
|
+
└── environment.py # load_env, set_seed, get_device, get_artifacts_dir
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Datasets
|
|
60
|
+
|
|
61
|
+
Each dataset is a module with its own types and a loader that downloads from Hugging Face, cached via `HF_HOME`.
|
|
62
|
+
|
|
63
|
+
### SynthPersona
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from persona_data.synth_persona import SynthPersonaDataset
|
|
67
|
+
|
|
68
|
+
dataset = SynthPersonaDataset()
|
|
69
|
+
|
|
70
|
+
persona = dataset[0]
|
|
71
|
+
persona.name # "Ethan Robinson"
|
|
72
|
+
persona.templated_view # short attribute-based system prompt
|
|
73
|
+
persona.biography_view # full biography text
|
|
74
|
+
persona.sections # list of BiographySection
|
|
75
|
+
|
|
76
|
+
qa_pairs = dataset.get_qa(persona.id, type="implicit", difficulty=[1, 2])
|
|
77
|
+
questions = dataset.questions(persona.id, type="explicit")
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
### PersonaGuess
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
from persona_data.persona_guess import PersonaGuessDataset
|
|
84
|
+
|
|
85
|
+
games = PersonaGuessDataset()
|
|
86
|
+
game = games[0]
|
|
87
|
+
turns = games.get_qa(game.game_id, player="A")
|
|
88
|
+
questions = games.questions(game.game_id, player="B")
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## Prompt formatting
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
from persona_data.prompts import format_messages, format_roleplay_prompt
|
|
95
|
+
|
|
96
|
+
system_prompt = format_roleplay_prompt(persona.biography_view)
|
|
97
|
+
|
|
98
|
+
messages = [
|
|
99
|
+
{"role": "system", "content": system_prompt},
|
|
100
|
+
{"role": "user", "content": "Where did you grow up?"},
|
|
101
|
+
{"role": "assistant", "content": "I grew up in Little Rock, Arkansas."},
|
|
102
|
+
]
|
|
103
|
+
full_prompt, response_start_idx = format_messages(messages, tokenizer)
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
`format_roleplay_prompt` supports `mode="roleplay"` (default), `mode="conversational"`, and `mode="mc"`.
|
|
107
|
+
|
|
108
|
+
`format_messages` handles tokenizers that do not support the `"system"` role (for example Gemma 2) by merging system content into the first user message.
|
|
109
|
+
|
|
110
|
+
For multiple-choice evaluation, use `format_mc_question(qa)` and `mc_correct_letter(qa)`.
|
|
111
|
+
|
|
112
|
+
## Environment helpers
|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
from persona_data.environment import load_env, set_seed, get_device, get_artifacts_dir
|
|
116
|
+
|
|
117
|
+
load_env() # loads .env from cwd (searches parent dirs)
|
|
118
|
+
set_seed(1337) # sets random, numpy, and torch seeds
|
|
119
|
+
device = get_device() # cuda > mps > cpu
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
## Used by
|
|
124
|
+
|
|
125
|
+
- [persona-vectors](https://github.com/implicit-personalization/persona-vectors) — activation extraction and steering
|
|
126
|
+
- [cues_attribution](https://github.com/implicit-personalization/io-analysis) — section-level ablation attribution
|
|
127
|
+
- [persona-2-lora](https://github.com/implicit-personalization/persona-2-lora) — LoRA-based persona internalization
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# persona-data
|
|
2
|
+
|
|
3
|
+
[](https://implicit-personalization.github.io/persona-data/)
|
|
4
|
+
|
|
5
|
+
Shared dataset loading, prompt formatting, and environment utilities for the [implicit-personalization](https://github.com/implicit-personalization) projects.
|
|
6
|
+
|
|
7
|
+
## Overview
|
|
8
|
+
|
|
9
|
+
`persona-data` provides the common dataset and prompt helpers used across the persona projects:
|
|
10
|
+
|
|
11
|
+
- `SynthPersonaDataset` for persona profiles plus QA pairs
|
|
12
|
+
- `PersonaGuessDataset` for turn-based persona games
|
|
13
|
+
- prompt helpers for roleplay and multiple-choice evaluation
|
|
14
|
+
- environment helpers for seeds, devices, and artifact paths
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
Add as a uv git source in your project's `pyproject.toml`:
|
|
19
|
+
|
|
20
|
+
```toml
|
|
21
|
+
[project]
|
|
22
|
+
dependencies = ["persona-data"]
|
|
23
|
+
|
|
24
|
+
[tool.uv.sources]
|
|
25
|
+
persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
Then run `uv sync`.
|
|
29
|
+
|
|
30
|
+
For local development alongside other repos, use an editable path source:
|
|
31
|
+
|
|
32
|
+
```toml
|
|
33
|
+
[tool.uv.sources]
|
|
34
|
+
persona-data = { path = "../persona-data", editable = true }
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Package layout
|
|
38
|
+
|
|
39
|
+
```
|
|
40
|
+
src/persona_data/
|
|
41
|
+
├── __init__.py
|
|
42
|
+
├── synth_persona.py # SynthPersonaDataset, PersonaDataset, PersonaData, QAPair, BiographySection
|
|
43
|
+
├── persona_guess.py # PersonaGuessDataset, GameRecord, Turn
|
|
44
|
+
├── prompts.py # format_roleplay_prompt, format_mc_question, format_messages
|
|
45
|
+
└── environment.py # load_env, set_seed, get_device, get_artifacts_dir
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Datasets
|
|
49
|
+
|
|
50
|
+
Each dataset is a module with its own types and a loader that downloads from Hugging Face, cached via `HF_HOME`.
|
|
51
|
+
|
|
52
|
+
### SynthPersona
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
from persona_data.synth_persona import SynthPersonaDataset
|
|
56
|
+
|
|
57
|
+
dataset = SynthPersonaDataset()
|
|
58
|
+
|
|
59
|
+
persona = dataset[0]
|
|
60
|
+
persona.name # "Ethan Robinson"
|
|
61
|
+
persona.templated_view # short attribute-based system prompt
|
|
62
|
+
persona.biography_view # full biography text
|
|
63
|
+
persona.sections # list of BiographySection
|
|
64
|
+
|
|
65
|
+
qa_pairs = dataset.get_qa(persona.id, type="implicit", difficulty=[1, 2])
|
|
66
|
+
questions = dataset.questions(persona.id, type="explicit")
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### PersonaGuess
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
from persona_data.persona_guess import PersonaGuessDataset
|
|
73
|
+
|
|
74
|
+
games = PersonaGuessDataset()
|
|
75
|
+
game = games[0]
|
|
76
|
+
turns = games.get_qa(game.game_id, player="A")
|
|
77
|
+
questions = games.questions(game.game_id, player="B")
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Prompt formatting
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
from persona_data.prompts import format_messages, format_roleplay_prompt
|
|
84
|
+
|
|
85
|
+
system_prompt = format_roleplay_prompt(persona.biography_view)
|
|
86
|
+
|
|
87
|
+
messages = [
|
|
88
|
+
{"role": "system", "content": system_prompt},
|
|
89
|
+
{"role": "user", "content": "Where did you grow up?"},
|
|
90
|
+
{"role": "assistant", "content": "I grew up in Little Rock, Arkansas."},
|
|
91
|
+
]
|
|
92
|
+
full_prompt, response_start_idx = format_messages(messages, tokenizer)
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
`format_roleplay_prompt` supports `mode="roleplay"` (default), `mode="conversational"`, and `mode="mc"`.
|
|
96
|
+
|
|
97
|
+
`format_messages` handles tokenizers that do not support the `"system"` role (for example Gemma 2) by merging system content into the first user message.
|
|
98
|
+
|
|
99
|
+
For multiple-choice evaluation, use `format_mc_question(qa)` and `mc_correct_letter(qa)`.
|
|
100
|
+
|
|
101
|
+
## Environment helpers
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
from persona_data.environment import load_env, set_seed, get_device, get_artifacts_dir
|
|
105
|
+
|
|
106
|
+
load_env() # loads .env from cwd (searches parent dirs)
|
|
107
|
+
set_seed(1337) # sets random, numpy, and torch seeds
|
|
108
|
+
device = get_device() # cuda > mps > cpu
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
## Used by
|
|
113
|
+
|
|
114
|
+
- [persona-vectors](https://github.com/implicit-personalization/persona-vectors) — activation extraction and steering
|
|
115
|
+
- [cues_attribution](https://github.com/implicit-personalization/io-analysis) — section-level ablation attribution
|
|
116
|
+
- [persona-2-lora](https://github.com/implicit-personalization/persona-2-lora) — LoRA-based persona internalization
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "persona-data"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Shared dataset loading and prompt formatting for implicit-personalization projects"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"huggingface-hub>=0.30.0",
|
|
9
|
+
"python-dotenv>=1.0.0",
|
|
10
|
+
"torch>=2.0.0",
|
|
11
|
+
"numpy>=1.24.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[build-system]
|
|
15
|
+
requires = ["uv_build>=0.11.3,<0.12"]
|
|
16
|
+
build-backend = "uv_build"
|
|
File without changes
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_env() -> None:
    """Load environment variables from .env file.

    Searches the current working directory and parent directories.
    """
    # python-dotenv's default find_dotenv walks upward until a .env file is
    # found; already-set environment variables are NOT overridden by default.
    load_dotenv()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_artifacts_dir() -> Path:
    """Root directory for generated artifacts.

    Controlled by the ARTIFACTS_DIR environment variable; falls back to a
    relative "artifacts" directory when the variable is unset.
    """
    configured = os.environ.get("ARTIFACTS_DIR")
    return Path(configured) if configured is not None else Path("artifacts")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def set_seed(seed: int) -> None:
    """Seed every RNG we use: random, numpy, torch (plus CUDA/MPS backends)."""
    # Same seeding order as before: stdlib, then numpy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_device() -> torch.device:
    """Best available compute device, preferring cuda, then mps, then cpu."""
    for name, available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if available():
            return torch.device(name)
    return torch.device("cpu")
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Iterator, Literal
|
|
5
|
+
|
|
6
|
+
from huggingface_hub import hf_hub_download
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Turn:
    """One question/answer exchange within a PersonaGuess game."""

    round: int
    asker: Literal["A", "B"]
    question: str
    answer: str

    def __repr__(self):
        # Compact repr: the long question/answer strings are omitted on purpose.
        return "Turn(round={}, asker={!r})".format(self.round, self.asker)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class GameRecord:
    """A full PersonaGuess game: two persona ids and the turn transcript."""

    game_id: str
    persona_a_id: str
    persona_b_id: str
    turns: list[Turn]

    def __repr__(self):
        # Summarize by turn count rather than dumping the whole transcript.
        return "GameRecord(game_id={!r}, turns={})".format(self.game_id, len(self.turns))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PersonaGuessDataset:
    """PersonaGuess game dataset loaded from HuggingFace."""

    DEFAULT_REPO = "implicit-personalization/persona-guess"

    def __init__(self, hf_repo: str = DEFAULT_REPO) -> None:
        # hf_hub_download caches under HF_HOME, so repeat loads only hit disk.
        path = Path(hf_hub_download(hf_repo, "games.jsonl", repo_type="dataset"))
        games: list[GameRecord] = []
        with open(path) as f:
            for line in f:
                record = json.loads(line)
                games.append(
                    GameRecord(
                        game_id=record["game_id"],
                        persona_a_id=record["persona_a_id"],
                        persona_b_id=record["persona_b_id"],
                        turns=[Turn(**t) for t in record["turns"]],
                    )
                )
        self._games: list[GameRecord] = games
        self._games_by_id: dict[str, GameRecord] = {g.game_id: g for g in games}

    def __repr__(self) -> str:
        return "PersonaGuessDataset(n_games={})".format(len(self._games))

    def __len__(self) -> int:
        return len(self._games)

    def __iter__(self) -> Iterator[GameRecord]:
        yield from self._games

    def __getitem__(self, idx: int) -> GameRecord:
        return self._games[idx]

    def get_qa(
        self, game_id: str, player: Literal["A", "B"] | None = None
    ) -> list[Turn]:
        """Turns of one game, optionally restricted to a single asker.

        Raises KeyError for an unknown game_id.
        """
        turns = self._games_by_id[game_id].turns
        if player is None:
            return list(turns)
        return [turn for turn in turns if turn.asker == player]

    def questions(
        self, game_id: str, player: Literal["A", "B"] | None = None
    ) -> list[str]:
        """Question strings only, with the same filtering as get_qa."""
        return [turn.question for turn in self.get_qa(game_id, player)]
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from persona_data.synth_persona import QAPair
|
|
4
|
+
|
|
5
|
+
_LETTERS = "ABCDE"

BASE_ROLEPLAY_PROMPT = """\
You are roleplaying as a specific person in a conversation.
Stay fully in character. Be truthful to the profile below.
Do not mention that you are an AI model.

### Person profile:

{persona}"""

_CONVERSATIONAL_SUFFIX = "\n\nAnswer naturally and conversationally as this person."

# Persona text used when a caller supplies no profile at all.
EMPTY_PERSONA_PLACEHOLDER = "Assistant"
MC_ANSWER_ONLY_INSTRUCTION = "Answer only with the correct choice label (A, B, C, ...)."
PromptMode = Literal["roleplay", "conversational", "mc"]


def format_roleplay_prompt(
    persona: str = EMPTY_PERSONA_PLACEHOLDER, mode: PromptMode = "roleplay"
) -> str:
    """Build the roleplay system prompt from any persona view.

    Args:
        persona: Persona text (templated, biography, or statements view).
        mode: "roleplay" returns the plain prompt; "conversational" appends a
            natural-chat instruction; "mc" appends the answer-only
            multiple-choice constraint.
    """
    rendered = BASE_ROLEPLAY_PROMPT.format(persona=persona)
    if mode == "mc":
        return "\n\n".join((rendered.rstrip(), MC_ANSWER_ONLY_INSTRUCTION))
    if mode == "conversational":
        return rendered + _CONVERSATIONAL_SUFFIX
    return rendered
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _format_mc_question_prompt(qa: QAPair) -> str:
    """Render the question followed by its lettered answer choices.

    Raises IndexError if there are more choices than letters in _LETTERS.
    """
    parts = [qa.question, ""]
    parts.extend(f"{_LETTERS[i]}. {choice}" for i, choice in enumerate(qa.choices))
    return "\n".join(parts)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def format_mc_question(qa: QAPair) -> str:
    """Format an MC question with lettered choices plus the answer-only instruction."""
    body = _format_mc_question_prompt(qa).rstrip()
    return body + "\n\n" + MC_ANSWER_ONLY_INSTRUCTION
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def mc_correct_letter(qa: QAPair) -> str:
    """Letter (A–E) of the correct choice; raises ValueError when unset."""
    index = qa.correct_choice_index
    if index is None:
        raise ValueError(f"QAPair {qa.qid!r} has no correct_choice_index")
    return _LETTERS[index]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _supports_system_role(tokenizer) -> bool:
    """Check if tokenizer's chat template supports the 'system' role.

    Probes with a throwaway system message instead of inspecting the
    template source; any exception is treated as "unsupported" because
    template failure types vary across tokenizer implementations.
    """
    try:
        tokenizer.apply_chat_template(
            [{"role": "system", "content": "test"}],
            tokenize=False,
        )
        return True
    except Exception:
        return False


def normalize_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    """Merge any leading system message into the first user message.

    Only needed when the tokenizer's chat template doesn't support
    the "system" role (e.g., Gemma 2). The input list is not mutated;
    a new list is returned when normalization applies.
    """
    if not messages or messages[0]["role"] != "system":
        return messages
    system_content = messages[0]["content"]
    rest = list(messages[1:])
    # Prepend the system text only when there is a user turn to attach it to
    # and the system content is non-empty; otherwise the system turn is dropped.
    if rest and rest[0]["role"] == "user" and system_content:
        rest[0] = {
            "role": "user",
            "content": f"{system_content}\n\n{rest[0]['content']}",
        }
    return rest


def format_messages(messages: list[dict[str, str]], tokenizer) -> tuple[str, int]:
    """Format a conversation for the model using its chat template.

    Args:
        messages: List of message dicts with "role" and "content" keys.
            Can include "system", "user", and "assistant" roles.
        tokenizer: The tokenizer with chat template support.

    Returns:
        full_prompt: The full formatted prompt as a string.
        response_start_idx: The token index of the first token in the last
            assistant message.
    """
    if not _supports_system_role(tokenizer):
        messages = normalize_messages(messages)

    full_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )

    # Tokenize everything before the last message (or the whole conversation
    # when there is at most one message) to locate where the response starts.
    prefix = messages if len(messages) <= 1 else messages[:-1]
    # BUGFIX: without return_dict=True, transformers' apply_chat_template with
    # tokenize=True returns a bare tensor/list of ids, so the original
    # ["input_ids"] lookup failed.  return_dict=True yields a BatchEncoding
    # mapping — NOTE(review): confirm against the pinned transformers version.
    prompt_without_response = tokenizer.apply_chat_template(
        prefix,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    response_start_idx = prompt_without_response["input_ids"].shape[1]

    return full_prompt, response_start_idx
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Iterator, Literal
|
|
6
|
+
|
|
7
|
+
# NOTE: The loader intentionally drops provenance/curation fields that are not used
|
|
8
|
+
# by persona-data, persona-ui, or persona-vectors. We keep `evidence_sids` because
|
|
9
|
+
# the ablation tooling still relies on it.
|
|
10
|
+
# Dropped fields: design_notes, family_name, source_candidate_*, bank_id, axis,
|
|
11
|
+
# curation_note, validation, evidence_claims, evidence_quotes.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class QAPair:
    """One question/answer pair attached to a persona.

    `choices` and `correct_choice_index` are populated only for
    multiple-choice items (answer_format == "choice"); free-text items
    leave them at their defaults.
    """

    qid: str
    type: Literal["explicit", "implicit"]
    question: str
    answer: str
    difficulty: int  # 1 = easy, 2 = medium, 3 = hard
    answer_format: str = ""  # "free_text" or "choice"
    choices: list[str] = field(default_factory=list)
    correct_choice_index: int | None = None  # index into `choices`, or None
    # presumably statement sids backing the answer (kept for ablation tooling
    # per the module NOTE) — confirm against the dataset schema
    evidence_sids: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)

    def __repr__(self):
        # Compact repr; the full question/answer text is omitted on purpose.
        return f"QAPair(qid={self.qid!r}, type={self.type!r}, difficulty={self.difficulty})"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class Statement:
    """An atomic claim about a persona, with optional supporting evidence."""

    sid: str
    category: str
    claim: str
    # Raw evidence dicts straight from the source JSONL; not interpreted here.
    support: list[dict] = field(default_factory=list)
    confidence: str = ""  # free-form confidence label from the source data

    def __repr__(self):
        return f"Statement(sid={self.sid!r}, category={self.category!r})"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class BiographySection:
    """A semantically coherent segment of a persona biography."""

    # assumed unique within a persona (used as the key in
    # PersonaData.sections_by_id) — confirm against the dataset
    section_id: str
    category: str  # e.g. "upbringing", "education", "career"
    title: str
    text: str  # Concatenated paragraph texts

    def __repr__(self):
        return f"BiographySection(id={self.section_id!r}, category={self.category!r})"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class PersonaData:
    """A synthetic persona: raw attributes plus rendered text views."""

    id: str
    persona: dict  # raw attribute dict; read via "first_name"/"last_name" in `name`
    templated_view: str
    biography_view: str
    statements_view: str = ""
    # NOTE: an upstream transcript view exists but is intentionally not loaded.
    sections: list[BiographySection] = field(default_factory=list)
    statements: list[Statement] = field(default_factory=list)

    @property
    def name(self) -> str:
        """Full display name built from the raw persona attributes."""
        attrs = self.persona
        return "{} {}".format(attrs["first_name"], attrs["last_name"])

    @property
    def sections_by_id(self) -> dict[str, BiographySection]:
        """Sections keyed by section_id (rebuilt on each access)."""
        return dict((sec.section_id, sec) for sec in self.sections)

    @property
    def sections_by_category(self) -> dict[str, list[BiographySection]]:
        """Sections grouped by category, preserving original order."""
        grouped: dict[str, list[BiographySection]] = {}
        for sec in self.sections:
            grouped.setdefault(sec.category, []).append(sec)
        return grouped

    @property
    def section_categories(self) -> list[str]:
        """Unique section categories in order."""
        return list(self.sections_by_category.keys())

    def get_section(self, section_id: str) -> BiographySection | None:
        """Single section lookup; None when the id is unknown."""
        return self.sections_by_id.get(section_id)

    def get_sections_by_category(self, category: str) -> list[BiographySection]:
        """All sections of one category; empty list when absent."""
        return self.sections_by_category.get(category, [])

    def __repr__(self):
        return "PersonaData(id={!r}, name={!r})".format(self.id, self.name)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class PersonaDataset:
    """Persona dataset loaded from local JSONL files.

    Reads one persona record per line from `personas_path` and one QA
    record per line from `qa_path`, indexing both by persona id.
    """

    def __init__(self, personas_path: Path | str, qa_path: Path | str) -> None:
        self._personas: list[PersonaData] = []
        self._personas_by_id: dict[str, PersonaData] = {}
        for record in self._read_jsonl(personas_path):
            persona = self._parse_persona(record)
            self._personas.append(persona)
            self._personas_by_id[persona.id] = persona

        self._qa: dict[str, list[QAPair]] = defaultdict(list)
        for record in self._read_jsonl(qa_path):
            self._qa[record["id"]].append(self._parse_qa(record))

    @staticmethod
    def _read_jsonl(path: Path | str) -> Iterator[dict]:
        """Yield one parsed JSON object per non-blank line."""
        with open(path) as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    @staticmethod
    def _parse_persona(d: dict) -> PersonaData:
        """Build a PersonaData (with sections and statements) from a raw record."""
        sections = [
            BiographySection(
                section_id=sec["section_id"],
                category=sec["category"],
                title=sec["title"],
                text="\n\n".join(p["text"] for p in sec.get("paragraphs", [])),
            )
            for sec in d.get("sections", [])
        ]
        statements = [
            Statement(
                sid=s["sid"],
                category=s["category"],
                claim=s["claim"],
                support=s.get("support", []),
                confidence=s.get("confidence", ""),
            )
            for s in d.get("statements", [])
        ]
        return PersonaData(
            id=d["id"],
            persona=d["persona"],
            templated_view=d["templated_view"],
            biography_view=d.get("biography_view", ""),
            statements_view=d.get("statements_view", ""),
            sections=sections,
            statements=statements,
        )

    @staticmethod
    def _parse_qa(d: dict) -> QAPair:
        """Build a QAPair from a raw record, defaulting optional fields."""
        return QAPair(
            qid=d["qid"],
            type=d["type"],
            question=d["question"],
            answer=d["answer"],
            difficulty=d["difficulty"],
            answer_format=d.get("answer_format", "free_text"),
            choices=d.get("choices", []),
            correct_choice_index=d.get("correct_choice_index"),
            evidence_sids=d.get("evidence_sids", []),
            tags=d.get("tags", []),
        )

    def __repr__(self) -> str:
        return "{}(n_personas={})".format(type(self).__name__, len(self._personas))

    def __len__(self) -> int:
        return len(self._personas)

    def __iter__(self) -> Iterator[PersonaData]:
        yield from self._personas

    def __getitem__(self, idx: int) -> PersonaData:
        return self._personas[idx]

    def get_persona(self, persona_id: str) -> PersonaData | None:
        return self._personas_by_id.get(persona_id)

    def get_qa(
        self,
        persona_id: str,
        type: Literal["explicit", "implicit"] | None = None,
        difficulty: int | list[int] | None = None,
    ) -> list[QAPair]:
        """Return QA pairs for a persona, optionally filtered by type and/or difficulty."""
        wanted_levels = None
        if difficulty is not None:
            wanted_levels = {difficulty} if isinstance(difficulty, int) else set(difficulty)
        result = []
        # .get (not subscript) avoids materializing a defaultdict entry for
        # unknown persona ids.
        for pair in self._qa.get(persona_id, []):
            if type is not None and pair.type != type:
                continue
            if wanted_levels is not None and pair.difficulty not in wanted_levels:
                continue
            result.append(pair)
        return result

    def questions(
        self,
        persona_id: str,
        type: Literal["explicit", "implicit"] | None = None,
        difficulty: int | list[int] | None = None,
    ) -> list[str]:
        """Like get_qa but returns question strings only."""
        return [pair.question for pair in self.get_qa(persona_id, type, difficulty)]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class SynthPersonaDataset(PersonaDataset):
    """SynthPersona dataset loaded from HuggingFace."""

    def __init__(self, hf_repo: str = "implicit-personalization/synth-persona") -> None:
        from huggingface_hub import hf_hub_download

        # HF Hub caches locally under HF_HOME so repeat runs are instant.
        def _fetch(filename: str) -> str:
            return hf_hub_download(hf_repo, filename, repo_type="dataset")

        super().__init__(
            personas_path=_fetch("dataset_personas.jsonl"),
            qa_path=_fetch("dataset_qa.jsonl"),
        )
|