sdialog 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdialog-0.0.1/LICENSE +21 -0
- sdialog-0.0.1/PKG-INFO +34 -0
- sdialog-0.0.1/README.md +7 -0
- sdialog-0.0.1/pyproject.toml +27 -0
- sdialog-0.0.1/requirements.txt +12 -0
- sdialog-0.0.1/setup.cfg +4 -0
- sdialog-0.0.1/src/sdialog/__init__.py +131 -0
- sdialog-0.0.1/src/sdialog/datasets.py +262 -0
- sdialog-0.0.1/src/sdialog/generators.py +82 -0
- sdialog-0.0.1/src/sdialog/orchestrators.py +224 -0
- sdialog-0.0.1/src/sdialog/personas.py +299 -0
- sdialog-0.0.1/src/sdialog/util.py +11 -0
- sdialog-0.0.1/src/sdialog.egg-info/PKG-INFO +34 -0
- sdialog-0.0.1/src/sdialog.egg-info/SOURCES.txt +15 -0
- sdialog-0.0.1/src/sdialog.egg-info/dependency_links.txt +1 -0
- sdialog-0.0.1/src/sdialog.egg-info/requires.txt +12 -0
- sdialog-0.0.1/src/sdialog.egg-info/top_level.txt +1 -0
sdialog-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Idiap Research Institute
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
sdialog-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sdialog
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Synthetic Dialogue Generation and Analysis
|
|
5
|
+
Author-email: Sergio Burdisso <sergio.burdisso@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/idiap/sdialog
|
|
8
|
+
Project-URL: Issues, https://github.com/idiap/sdialog/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.9
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: print-color
|
|
15
|
+
Requires-Dist: langchain
|
|
16
|
+
Requires-Dist: langchain-ollama
|
|
17
|
+
Requires-Dist: tqdm
|
|
18
|
+
Requires-Dist: plotly
|
|
19
|
+
Requires-Dist: sentence-transformers
|
|
20
|
+
Requires-Dist: pandas
|
|
21
|
+
Requires-Dist: tenacity
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: flake8
|
|
24
|
+
Requires-Dist: pytest
|
|
25
|
+
Requires-Dist: ollama
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# SDialog
|
|
29
|
+
|
|
30
|
+
Synthetic Dialogue Generation and Analysis
|
|
31
|
+
|
|
32
|
+
_(Coming soon)_
|
|
33
|
+
|
|
34
|
+
This package requires `Ollama` to be running on your system.
|
sdialog-0.0.1/README.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools >= 77.0.3"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "sdialog"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
authors = [
|
|
9
|
+
{ name="Sergio Burdisso", email="sergio.burdisso@gmail.com" },
|
|
10
|
+
]
|
|
11
|
+
description = "Synthetic Dialogue Generation and Analysis"
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
requires-python = ">=3.9"
|
|
14
|
+
dynamic = ["dependencies"]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"Operating System :: OS Independent",
|
|
18
|
+
]
|
|
19
|
+
license = "MIT"
|
|
20
|
+
license-files = ["LICEN[CS]E*"]
|
|
21
|
+
|
|
22
|
+
[tool.setuptools.dynamic]
|
|
23
|
+
dependencies = {file = ["requirements.txt"]}
|
|
24
|
+
|
|
25
|
+
[project.urls]
|
|
26
|
+
Homepage = "https://github.com/idiap/sdialog"
|
|
27
|
+
Issues = "https://github.com/idiap/sdialog/issues"
|
sdialog-0.0.1/setup.cfg
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from typing import List, Union, Optional
|
|
5
|
+
from print_color import print
|
|
6
|
+
|
|
7
|
+
from .util import make_serializable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Turn(BaseModel):
    """A single turn (utterance) of a dialogue."""

    speaker: Optional[str]  # name of the speaker; has no default, but None is accepted
    text: str  # the content of the utterance
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Event(BaseModel):
    """A single entry of a dialogue's event log (utterances, instructions, ...)."""

    agent: Optional[str] = None  # who produced the event, e.g. "user", "system"
    action: str  # kind of event, e.g. "utter", "instruct"
    actionLabel: Optional[str] = None  # action label (e.g. type of instruct)
    text: str  # the content of the event
    timestamp: int  # integer timestamp (callers in this package pass Unix time in seconds)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Dialog(BaseModel):
    """A dialogue: its list of turns plus optional generation metadata and event log."""

    formatVersion: Optional[str] = "0.0.5"  # version of the format
    model: Optional[str] = None  # the model used to generate the dialogue
    seed: Optional[int] = None  # the seed used to generate the dialogue
    dialogId: Optional[int] = None
    complete: Optional[bool] = None
    scenario: Optional[Union[dict, str]] = None  # the scenario used to generate the dialogue
    turns: List[Turn]  # the list of turns of the conversation
    events: Optional[List[Event]] = None

    def __len__(self):
        """Return the number of turns in the dialogue."""
        return len(self.turns)

    def description(self, turn_template: str = "{speaker}: {text}"):
        """Return the dialogue as text, one turn per line, each formatted with ``turn_template``.

        Newlines inside a turn's text are replaced by spaces so every turn fits on one line
        (which is what ``from_file`` relies on when reading ".txt" files back).
        """
        return "\n".join(turn_template.format(speaker=turn.speaker, text=turn.text.replace("\n", " "))
                         for turn in self.turns)

    def json(self, string: bool = False, indent: Optional[int] = None):
        """Return the dialogue as a JSON-serializable dict, or as a JSON string if ``string`` is True."""
        data = self.model_dump()
        make_serializable(data)
        return json.dumps(data, indent=indent) if string else data

    def print(self, *a, **kw):
        """Pretty-print the dialogue (arguments are forwarded to ``print_dialog``)."""
        print_dialog(self, *a, **kw)

    def to_file(self, path: str, type: str = "auto"):
        """Save the dialogue to ``path``.

        Args:
            path: output file path.
            type: "json", "txt", or "auto" (the default), which infers the type from the file
                extension (".json" -> JSON, anything else -> plain text).
        """
        if type == "auto":
            type = "json" if path.endswith(".json") else "txt"

        # Explicit UTF-8 so non-ASCII dialogues do not depend on the platform's locale encoding
        with open(path, "w", encoding="utf-8") as writer:
            if type == "json":
                writer.write(self.json(string=True))
            else:
                writer.write(self.description())

    @staticmethod
    def from_file(path: str, type: str = "auto"):
        """Load a dialogue from ``path``.

        Args:
            path: input file path.
            type: "json", "txt", or "auto" (the default), which infers the type from the file
                extension (".json" -> JSON, anything else -> plain text).

        Plain-text files are expected to contain one "<speaker>: <text>" turn per line
        (empty lines are ignored).

        Raises:
            ValueError: if a non-empty line of a text file has no ":" separator.
        """
        if type == "auto":
            type = "json" if path.endswith(".json") else "txt"

        with open(path, encoding="utf-8") as reader:
            if type == "json":
                return Dialog.model_validate(json.load(reader))
            lines = reader.read().split("\n")

        turns = []
        for line in lines:
            if not line:
                continue
            # split on the first ":" only, so the text itself may contain colons
            speaker, sep, text = line.partition(":")
            if not sep:
                raise ValueError(f"Invalid line in '{path}', expected '<speaker>: <text>': {line!r}")
            turns.append(Turn(speaker=speaker.strip(), text=text.strip()))
        return Dialog(turns=turns)

    @staticmethod
    def from_dict(data: dict):
        """Build a dialogue from a dict (alias of ``Dialog.model_validate``)."""
        return Dialog.model_validate(data)

    __str__ = description
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Instruction(BaseModel):
    """An instruction text with optional extra events attached to it."""

    # Optional: the default is None, so None must also be an accepted value
    # (with a plain ``str`` annotation, ``Instruction(text=None)`` failed validation).
    text: Optional[str] = None  # the instruction itself
    events: Optional[Union[Event, List[Event]]] = None  # extra events
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def print_dialog(dialog: Union[Dialog, dict], scenario: bool = False, orchestration: bool = False):
    """Pretty-print a dialogue to the terminal, coloring each speaker differently.

    Args:
        dialog: the dialogue to print, a ``Dialog`` or a dict that validates as one.
        scenario: if True, also print the dialogue's scenario (when it has one).
        orchestration: if True and the dialogue has events, print the event log (utterances,
            picked suggestions and other events such as instructions) instead of the turns.
    """
    if isinstance(dialog, dict):
        dialog = Dialog.model_validate(dialog)

    speaker_tag_colors = ["red", "blue", "yellow", "cyan", "green", "magenta", "purple"]
    speaker_utt_colors = ["grey", "white"]

    # ``is not None`` so that valid zero values (e.g. seed=0) are still shown
    if dialog.dialogId is not None:
        print(dialog.dialogId, tag="dialog_id", tag_color="purple", color="magenta", format="bold")
    if dialog.complete:
        print(dialog.complete, tag="complete", tag_color="purple", color="magenta", format="bold")
    if dialog.model:
        print(dialog.model, tag="model", tag_color="purple", color="magenta", format="bold")
    if dialog.seed is not None:
        print(dialog.seed, tag="seed", tag_color="purple", color="magenta", format="bold")
    if scenario and dialog.scenario:
        print("", tag="scenario", tag_color="purple", color="magenta", format="bold")
        if isinstance(dialog.scenario, str):
            print(dialog.scenario, color="magenta")
        else:
            print(json.dumps(dialog.scenario, indent=2), color="magenta")

    print("--- Dialogue Begins ---", color="magenta", format="bold")
    # Turn.speaker is Optional: sort with a key that tolerates None (placed last);
    # for string-only speakers this is the same order as a plain sort.
    speakers = sorted({turn.speaker for turn in dialog.turns}, key=lambda s: (s is None, s or ""))

    if orchestration and dialog.events:
        def event_to_turn(e):
            """Map an event to a printable turn (non-utterance events are tagged by action)."""
            if e.action == "utter":
                return Turn(speaker=e.agent, text=e.text)
            if e.action == "pick_suggestion":
                return Turn(speaker=e.agent, text=f"[pick_suggestion] {e.text}")
            return Turn(speaker=e.action, text=f"({e.agent}) {e.text}")

        dialog = dialog.model_copy()
        dialog.turns = [event_to_turn(e) for e in dialog.events]

    for turn in dialog.turns:
        speaker = turn.speaker

        if speaker in speakers:
            speaker_ix = speakers.index(speaker)
            tag_color = speaker_tag_colors[speaker_ix % len(speaker_tag_colors)]
            color = speaker_utt_colors[speaker_ix % len(speaker_utt_colors)]
        else:
            # tags that are not dialogue speakers (e.g. orchestration events)
            tag_color = "yellow"
            color = "purple"

        print(turn.text,
              tag=speaker,
              tag_color=tag_color,
              color=color)
    print("--- Dialogue Ends ---", color="magenta", format="bold")
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from tqdm.auto import tqdm
|
|
6
|
+
|
|
7
|
+
from . import Dialog, Turn, Event
|
|
8
|
+
from .personas import Persona, PersonaAgent
|
|
9
|
+
from .orchestrators import InstructionListOrchestrator, SimpleResponseOrchestrator
|
|
10
|
+
|
|
11
|
+
class STAR:
    """Helpers to read the STAR dataset and build personas/agents from its scenarios.

    Call ``STAR.set_path(<dataset root>)`` first. The root is expected to contain
    ``dialogues/<id>.json`` and ``tasks/<task_name>/<task_name>.json`` /
    ``tasks/<task_name>/responses.json`` files.
    """
    _path = None
    _speakers = ["User", "Wizard"]

    @staticmethod
    def set_path(path):
        """Set the root folder of the STAR dataset used by all other methods."""
        STAR._path = path

    @staticmethod
    def _get_file_path(relative_path):
        """Return ``relative_path`` joined to the dataset root, failing clearly if it is not set."""
        if STAR._path is None:
            raise ValueError("STAR dataset path is not set, call `STAR.set_path(path)` first.")
        return os.path.join(STAR._path, relative_path)

    @staticmethod
    def _load_json(relative_path):
        """Load the JSON file at ``relative_path`` (relative to the dataset root)."""
        with open(STAR._get_file_path(relative_path), encoding="utf-8") as reader:
            return json.load(reader)

    @staticmethod
    def read_graph(task_name, as_dot: bool = True):
        """Return the flowchart graph of ``task_name``.

        As a DOT digraph string if ``as_dot`` is True, otherwise as the raw "graph" dict of the
        task file (each key is an edge source, its value the edge target).
        """
        graph = STAR._load_json(f"tasks/{task_name}/{task_name}.json")["graph"]
        if not as_dot:
            return graph
        dot_edges = ";\n".join(f" {a} -> {b}" for a, b in graph.items())
        return "digraph %s {\n%s\n}" % (task_name, dot_edges)

    @staticmethod
    def read_graph_responses(task_name, as_dict: bool = False):
        """Return the example response of each graph node of ``task_name`` (excluding "out_of_scope").

        Placeholders such as ``{name}`` or ``{name:type}`` are replaced by the uppercased name
        (e.g. ``NAME``). Returned as a dict if ``as_dict`` is True, otherwise as indented JSON text.
        """
        responses = STAR._load_json(f"tasks/{task_name}/responses.json")
        responses = {key: re.sub(r"{(.+?)(?::\w+?)?}", lambda m: m.group(1).upper(), value)
                     for key, value in responses.items()
                     if key != "out_of_scope"}
        return responses if as_dict else json.dumps(responses, indent=2)

    @staticmethod
    def _dialog_from_json(id, dialog):
        """Build a ``Dialog`` from the raw JSON content of STAR dialogue ``id``.

        "Wizard" agents are renamed to "System" (this mutates ``dialog["Events"]`` in place).
        """
        for e in dialog["Events"]:
            if e["Agent"] == "Wizard":
                e["Agent"] = "System"

        return Dialog(
            dialogId=id,
            scenario=dialog["Scenario"],
            turns=[Turn(speaker=e["Agent"], text=e["Text"])
                   for e in dialog["Events"]
                   if e["Action"] in ["utter", "pick_suggestion"]],
            events=[Event(agent=e["Agent"],
                          action=e["Action"],
                          actionLabel=e.get("ActionLabel"),
                          text=e["Text"],
                          timestamp=e["UnixTime"])
                    for e in dialog["Events"]
                    if "Text" in e]
        )

    @staticmethod
    def get_dialog(id):
        """Return STAR dialogue ``id`` as a ``Dialog`` (with "Wizard" renamed to "System")."""
        return STAR._dialog_from_json(id, STAR._load_json(f"dialogues/{id}.json"))

    @staticmethod
    def get_dialogs(domain: str = None, task_name: str = None, happy: bool = None, multitask: bool = None):
        """Return all dialogues matching the given filters (a None filter matches everything).

        Args:
            domain: keep dialogues whose scenario "Domains" contain it.
            task_name: keep dialogues with a wizard capability for this task.
            happy: keep dialogues whose scenario "Happy" flag equals it.
            multitask: keep dialogues whose scenario "MultiTask" flag equals it.
        """
        dialogs = []
        for fname in tqdm(os.listdir(STAR._get_file_path("dialogues/")), desc="Reading dialogs", leave=False):
            if not fname.endswith(".json"):
                continue
            dialog_id = int(os.path.splitext(fname)[0])
            # read each file only once (it is used both for filtering and for building the Dialog)
            dialog = STAR._load_json(f"dialogues/{dialog_id}.json")
            scenario = dialog["Scenario"]

            if ((domain is None or domain in scenario["Domains"])
                    and (happy is None or scenario["Happy"] == happy)
                    and (multitask is None or scenario["MultiTask"] == multitask)
                    and (task_name is None or any(capability["Task"] == task_name
                                                  for capability in scenario["WizardCapabilities"]))):
                dialogs.append(STAR._dialog_from_json(dialog_id, dialog))
        return dialogs

    @staticmethod
    def get_dialog_scenario(id):
        """Return the raw "Scenario" dict of STAR dialogue ``id``."""
        return STAR._load_json(f"dialogues/{id}.json")["Scenario"]

    @staticmethod
    def get_dialog_first_turn(id, speaker: str = None):
        """Return the first turn of dialogue ``id`` (from the raw events, "Wizard" not renamed).

        If ``speaker`` is None, the first event by "User" or "Wizard" is returned; otherwise the
        first event by ``speaker``. Returns None if there is no such event.
        """
        for event in STAR.get_dialog_events(id):
            turn_speaker = event["Agent"]
            if speaker is None and turn_speaker in STAR._speakers:
                return Turn(speaker=turn_speaker, text=event["Text"])
            elif turn_speaker == speaker:
                return Turn(speaker=turn_speaker, text=event["Text"])
        return None

    @staticmethod
    def get_dialog_task_names(id):
        """Return the names of the tasks of dialogue ``id`` (its wizard capabilities)."""
        scenario = STAR.get_dialog_scenario(id)
        return [task["Task"] for task in scenario["WizardCapabilities"]]

    @staticmethod
    def get_dialog_responses(id):
        """Return, for each task of dialogue ``id``, its node -> example response dict."""
        tasks = STAR.get_dialog_task_names(id)
        return [STAR.read_graph_responses(task, as_dict=True) for task in tasks]

    @staticmethod
    def get_dialog_graphs(id):
        """Return, for each task of dialogue ``id``, its raw graph dict."""
        tasks = STAR.get_dialog_task_names(id)
        return [STAR.read_graph(task, as_dot=False) for task in tasks]

    @staticmethod
    def get_dialog_events(id):
        """Return the raw "Events" list of STAR dialogue ``id``."""
        return STAR._load_json(f"dialogues/{id}.json")["Events"]

    @staticmethod
    def get_dialog_user_instructions(id):
        """Return the "UserGuide" instructions of dialogue ``id`` keyed by position.

        Each key is the number of "User" utterances that happened before the instruction; when
        several instructions share the same key, the last one is kept.
        """
        instructions = {}
        n_user_turns = 0  # running count, avoids re-scanning the events for every instruction
        for e in STAR.get_dialog_events(id):
            if e["Agent"] == "User" and e["Action"] == "utter":
                n_user_turns += 1
            elif e["Action"] == "instruct" and e["Agent"] == "UserGuide":
                instructions[n_user_turns] = e["Text"]
        return instructions

    @staticmethod
    def get_dialog_graphs_and_responses(id):
        """Return ``(get_dialog_graphs(id), get_dialog_responses(id))``."""
        return STAR.get_dialog_graphs(id), STAR.get_dialog_responses(id)

    @staticmethod
    def get_scenario_description(scenario):
        """Return a natural language description of a STAR ``scenario`` (including its flowcharts)."""
        # Let's generate the graph description for each task:
        flowcharts = ""
        for task in scenario["WizardCapabilities"]:
            task_name = task["Task"]
            flowcharts += f"""
The graph for the task '{task_name}' with domain '{task['Domain']}' is:
```dot
{STAR.read_graph(task_name)}
```
and one example responses for each node is provided in the following json:
```json
{STAR.read_graph_responses(task_name)}
```

---
"""
        # Finally, let's return the natural language description of the scenario.
        return f"""The conversation is between a User and a AI assistant in the following domains: {', '.join(scenario['Domains'])}.

The User instructions are: {scenario['UserTask']}
The AI assistant instructions are: {scenario['WizardTask']}

In addition, the AI assistant is instructed to follow specific flowcharts to address the tasks. Flowcharts are defined as graph described using DOT.
The actual DOT for the current tasks are:
{flowcharts}

Finally, the following should be considered regarding the conversation:
1. {"The conversation follows the 'happy path', meaning the conversations goes according to what it is described in the flowcharts"
    if scenario['Happy'] else
    "The conversation does NOT follow a 'happy path', meaning something happend to the user to change its mind or something happend "
    "in the environment for the conversation to not flow as expected, as described in the flowchart"}.
2. {"The user is calling to perform multiple tasks, involving all the tasks defined as flowcharts above (" + ', '.join(task['Task'] for task in scenario['WizardCapabilities']) + ")"
    if scenario['MultiTask'] else
    "The user is calling to perform only the defined task (" + scenario['WizardCapabilities'][0]['Task'] + "), nothing else"}.
"""

    @staticmethod
    def get_dialog_scenario_description(id):
        """Return ``(scenario, description)`` for dialogue ``id``."""
        scenario = STAR.get_dialog_scenario(id)
        return scenario, STAR.get_scenario_description(scenario)

    @staticmethod
    def get_user_persona_for_scenario(scenario):
        """Return the ``Persona`` of the user for a STAR ``scenario``."""
        dialogue_details = f"""
The following should be considered regarding the conversation:
1. {"The conversation follows a 'happy path', meaning the conversations goes smoothly without any unexpected behavior"
    if scenario['Happy'] else
    "The conversation does NOT follow a 'happy path', meaning you have to simulate something happend in the middle of the conversation, "
    "perhaps you changed your mind at some point or something external happend in the environment for the conversation to not flow as expected"}.
2. {"The conversation involves multiple tasks, that is, you want the assistant to perform multiple tasks (" + ', '.join(task['Task'] for task in scenario['WizardCapabilities']) + "), not just one."
    if scenario['MultiTask'] else
    "The conversation involves only one task you were instructed to (" + scenario['WizardCapabilities'][0]['Task'] + "), nothing else"}"""

        return Persona(
            role=f"user calling a AI assistant that can perform multiple tasks in the following domains: {', '.join(scenario['Domains'])}.\n" + dialogue_details,
            circumstances=scenario["UserTask"],
        )

    @staticmethod
    def get_flowchart_description_for_scenario(scenario):
        """Return a markdown description of the flowchart (DOT graph + example responses) of each task."""
        flowcharts = ""
        for task in scenario["WizardCapabilities"]:
            task_name = task["Task"]
            flowcharts += f"""
## {task_name} ({task['Domain']})

The flowchart described as an action transition graph for the task '{task_name}' with domain '{task['Domain']}' is:
```dot
{STAR.read_graph(task_name)}
```
Response example for each action is provided in the following json:
```json
{STAR.read_graph_responses(task_name)}
```
where UPPERCASE words above are just example placeholders. You MUST fill in those with any coherent values in the actual conversation.
"""
        return flowcharts

    @staticmethod
    def get_system_persona_for_scenario(scenario):
        """Return the ``Persona`` of the AI assistant (system) for a STAR ``scenario``."""
        dialogue_details = f"""In the conversation, the AI assistant is instructed to follow specific action flowcharts to address the tasks. Flowcharts are defined as graph described using DOT.
The actual DOT for the current tasks are:
{STAR.get_flowchart_description_for_scenario(scenario)}
"""
        return Persona(
            role="AI assistant.\n" + dialogue_details,
            circumstances=scenario['WizardTask'],
        )

    @staticmethod
    def get_agents_for_scenario(scenario, model_name):
        """Return ``(system, user)`` persona agents for a STAR ``scenario`` using model ``model_name``."""
        user = PersonaAgent(model_name,
                            STAR.get_user_persona_for_scenario(scenario),
                            name="User",
                            can_finish=True)

        system = PersonaAgent(model_name,
                              STAR.get_system_persona_for_scenario(scenario),
                              name="System")

        return system, user

    @staticmethod
    def get_agents_from_dialogue(id, model_name: str, set_first_utterance: bool = False):
        """Return ``(system, user)`` agents for the scenario of dialogue ``id``.

        If ``set_first_utterance`` is True, the system's first utterance is the dialogue's
        original first "Wizard" turn when the wizard spoke first, otherwise a default greeting.
        """
        scenario = STAR.get_dialog_scenario(id)
        system, user = STAR.get_agents_for_scenario(scenario, model_name)

        if set_first_utterance:
            first_turn = STAR.get_dialog_first_turn(id)
            # guard: a dialogue may have no "User"/"Wizard" event at all
            if first_turn is not None and first_turn.speaker == "Wizard":
                system.set_first_utterances(first_turn.text)
            else:
                system.set_first_utterances("Hello, how can I help?")

        return system, user

    @staticmethod
    def get_agents_from_dialogue_with_orchestration(id, model_name: str, set_first_utterance: bool = False):
        """Like ``get_agents_from_dialogue`` but with orchestrators attached.

        The system gets a response orchestrator built from the first task's responses and graph;
        the user gets a persistent orchestrator replaying the original "UserGuide" instructions.
        """
        system, user = STAR.get_agents_from_dialogue(id, model_name, set_first_utterance)

        graphs, responses = STAR.get_dialog_graphs_and_responses(id)
        # NOTE(review): only the first task is used, even for multi-task dialogues
        response_action_orchestrator = SimpleResponseOrchestrator(responses[0], graph=graphs[0])
        instr_list_orchestrator = InstructionListOrchestrator(
            STAR.get_dialog_user_instructions(id),
            persistent=True
        )

        return system | response_action_orchestrator, user | instr_list_orchestrator
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
from langchain_ollama.chat_models import ChatOllama
|
|
5
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from print_color import print
|
|
10
|
+
from typing import Union, List
|
|
11
|
+
from langchain_ollama.chat_models import ChatOllama
|
|
12
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
13
|
+
|
|
14
|
+
from . import Dialog, Turn
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LLMDialogOutput(BaseModel):
    """Structured output expected from the LLM: the whole generated dialogue as a list of turns.

    Its JSON schema is what ``DialogGenerator`` uses as the LLM output format by default.
    """

    dialog: List[Turn]  # the generated turns, in order
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# TODO: create a BaseDialogGenerator, and also PersonaDialogGenerator that takes personas objects as in multi-agent.
|
|
22
|
+
class DialogGenerator:
    """Generate a full synthetic dialogue with a single LLM call from a textual description."""

    def __init__(self, model: Union[ChatOllama, str], dialogue_details: str,
                 output_format: Union[dict, BaseModel] = LLMDialogOutput, scenario: dict = None):
        """Optional `scenario` to populate the "scenario" field of the output, if not provided, `dialogue_details` content will be used.

        Args:
            model: an Ollama model name (a ``ChatOllama`` is created for it) or an existing
                ``ChatOllama`` instance.
            dialogue_details: description of the dialogue to generate, sent as the user message.
            output_format: a pydantic model class describing the structured output, or a dict
                (JSON schema) / falsy value, in which case ``generate`` returns the raw LLM text.
            scenario: optional value for the "scenario" field of the generated ``Dialog``.
        """
        if not output_format or isinstance(output_format, dict):
            output_format_schema = output_format
            self.output_format = None
        else:
            output_format_schema = output_format.model_json_schema()
            self.output_format = output_format

        if isinstance(model, str):
            self.llm = ChatOllama(model=model,
                                  format=output_format_schema,
                                  temperature=0.8,
                                  seed=13)
            self.model_name = model
        else:
            self.llm = model
            if output_format:
                # Ollama's `format` expects a JSON schema (or "json"), not the pydantic class itself
                self.llm.format = output_format_schema
            # keep a plain string so it can be stored in ``Dialog.model``
            self.model_name = getattr(model, "model", str(model))

        self.set(dialogue_details, scenario)

    def generate(self, seed: int = None, id: int = None):
        """Generate a dialogue.

        Args:
            seed: seed for the LLM; a random 32-bit seed is drawn when None.
            id: optional value for the ``dialogId`` of the returned ``Dialog``.

        Returns:
            A ``Dialog`` when the output format is ``LLMDialogOutput``, an instance of the custom
            output format class otherwise, or the raw LLM text when no format class is set.
        """
        self.llm.seed = seed if seed is not None else random.getrandbits(32)

        # hack to avoid seed bug in prompt cache (to force a new cache, related to
        # https://github.com/ollama/ollama/issues/5321): run a 1-token generation first,
        # always restoring the original `num_predict` even if that call fails.
        num_predict = self.llm.num_predict
        self.llm.num_predict = 1
        try:
            self.llm.invoke(self.messages)
        finally:
            self.llm.num_predict = num_predict

        dialogue = self.llm.invoke(self.messages).content

        if not self.output_format:
            return dialogue

        llm_output = self.output_format.model_validate(json.loads(dialogue))

        if self.output_format is LLMDialogOutput:
            return Dialog(dialogId=id,  # keep 0 as a valid id
                          model=self.model_name,
                          seed=self.llm.seed,
                          scenario=self.scenario if self.scenario else self.dialogue_details,
                          turns=llm_output.dialog)
        return llm_output

    def set(self, dialogue_details: str, scenario: dict = None):
        """(Re)set the dialogue description and optional scenario, rebuilding the prompt messages."""
        self.scenario = scenario
        self.dialogue_details = dialogue_details
        self.messages = [
            SystemMessage(
                "You are a knowledgeable and useful AI assistant that can write natural conversations by role playing different speakers. "
                "The output should be a full dialogue, from beginning (greetings) to end (bye bye messages)."
            ),
            HumanMessage(content=dialogue_details)
        ]

    __call__ = generate
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import random
|
|
3
|
+
import inspect
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from time import time
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import List, Union, Dict, Optional
|
|
9
|
+
from sentence_transformers import SentenceTransformer
|
|
10
|
+
from langchain_core.messages import SystemMessage, AIMessage
|
|
11
|
+
|
|
12
|
+
from . import Turn, Event, Instruction
|
|
13
|
+
from .util import make_serializable
|
|
14
|
+
# from .personas import PersonaAgent
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseOrchestrator(ABC):
    """Base class for orchestrators: objects attached to a target agent that inspect the current
    dialogue and may return an instruction for it (subclasses implement ``instruct``)."""
    _target = None
    _event_label = None
    _persistent = False

    def __init__(self, target_agent=None, persistent: bool = None, event_label: str = None):
        """
        Args:
            target_agent: the agent this orchestrator is attached to (can also be set later).
            persistent: value reported by ``is_persistent``; when None, the class-level default
                ``_persistent`` is kept.
            event_label: label for the events of this orchestrator (defaults to the class name).
        """
        self._target = target_agent
        if persistent is not None:
            # Only override the class default when explicitly given; otherwise subclasses that
            # declare ``_persistent = True`` would be reset to None by ``super().__init__()``.
            self._persistent = persistent
        self._event_label = event_label

    def __call__(self):
        """Return ``instruct(dialog, utterance)`` for the target agent's current dialogue.

        ``utterance`` is the last turn's text unless that turn was produced by the target agent
        itself (or the dialogue is empty), in which case it is "".
        """
        dialog = self.__get_current_dialog()
        return self.instruct(dialog, dialog[-1].text if dialog and dialog[-1].speaker != self._target.get_name() else "")

    def __str__(self) -> str:
        data = self.json()
        attrs = " ".join(f"{key}={value}" for key, value in data["args"].items())
        return f"{data['name']}({attrs})"

    def __get_current_dialog(self) -> List[Turn]:
        """Rebuild the dialogue from the target agent's memory, skipping system messages.

        AI messages are attributed to the target agent; other messages get ``speaker=None``.
        """
        return [Turn(speaker=self._target.get_name() if isinstance(message, AIMessage) else None,
                     text=message.content)
                for message in self._target.memory if not isinstance(message, SystemMessage)]

    def _set_target_agent(self, agent):  # target: PersonaAgent
        self._target = agent

    def json(self, string: bool = False, indent: int = None):
        """Return ``{"name": <class name>, "args": {...}}`` as a dict, or a JSON string if ``string``.

        ``args`` holds the non-None instance attributes named like ``__init__`` parameters.
        """
        sig = inspect.signature(self.__init__)
        data = {"name": type(self).__name__,
                "args": {key: self.__dict__[key] for key in sig.parameters
                         if key in self.__dict__ and self.__dict__[key] is not None}}
        make_serializable(data["args"])
        return json.dumps(data, indent=indent) if string else data

    def get_event_label(self) -> str:
        """Return the event label, or the class name if none was given."""
        return self._event_label if self._event_label else type(self).__name__

    def get_target_agent(self):
        return self._target

    def is_persistent(self):
        return self._persistent

    def set_persistent(self, value: bool):
        self._persistent = value

    def agent_response_lookahead(self):
        """Return the target agent's ``response_lookahead()``."""
        return self._target.response_lookahead()

    @abstractmethod
    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return an instruction for the target agent, or None for no instruction."""
        pass

    def reset(self):
        """Reset any internal state (no-op by default)."""
        pass
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class BasePersistentOrchestrator(BaseOrchestrator):
    """Base class for orchestrators whose class-level ``_persistent`` default is True."""

    _persistent = True

    @abstractmethod
    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return an instruction for the target agent (to be implemented by subclasses)."""
        ...

    def reset(self):
        """Reset internal state; there is nothing to reset by default."""
        return None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class SimpleReflexOrchestrator(BaseOrchestrator):
    """Return a fixed instruction whenever a condition on the last utterance holds."""

    def __init__(self, condition: callable, instruction: str, persistent: bool = False, event_label: str = None):
        """
        Args:
            condition: predicate called with the last utterance.
            instruction: instruction returned whenever ``condition`` is truthy.
            persistent: forwarded to ``BaseOrchestrator``.
            event_label: forwarded to ``BaseOrchestrator``.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.condition = condition
        self.instruction = instruction

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return ``self.instruction`` if ``condition(utterance)`` is truthy, otherwise None."""
        return self.instruction if self.condition(utterance) else None
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class LengthOrchestrator(BaseOrchestrator):
    """Steer the dialogue length: keep it going below ``min`` turns, wrap it up near ``max``."""

    def __init__(self, min: int = None, max: int = None, persistent: bool = False, event_label: str = None):
        """
        Args:
            min: minimum number of turns (ignored if None).
            max: maximum number of turns (ignored if None or 0).
            persistent: forwarded to ``BaseOrchestrator``.
            event_label: forwarded to ``BaseOrchestrator``.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.max = max
        self.min = min

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return a "keep going" or "finish now" instruction depending on the dialogue length."""
        n_turns = len(dialog)

        if self.min is not None and 1 < n_turns < self.min:
            return "Make sure you DO NOT finish the conversation, keep it going!"

        # "- 1" leaves room for the answer about to be generated
        if self.max and n_turns >= self.max - 1:
            return "Now FINISH the conversation AS SOON AS possible, if possible, RIGHT NOW!"

        return None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ChangeMindOrchestrator(BaseOrchestrator):
    """Randomly instruct the agent to change its mind, at most ``max_times`` times."""

    def __init__(self, probability: float = 0.3,
                 reasons: Union[str, List[str]] = None,
                 max_times: int = 1,
                 persistent: bool = False,
                 event_label: str = None):
        """
        Args:
            probability: probability of emitting the instruction at each call.
            reasons: a reason (or list of reasons, one picked at random) appended to the instruction.
            max_times: maximum number of instructions to emit (unlimited if None or 0).
            persistent: forwarded to ``BaseOrchestrator``.
            event_label: forwarded to ``BaseOrchestrator``.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.probability = probability
        self.reasons = [reasons] if isinstance(reasons, str) else reasons
        self.max_times = max_times
        self.times = 0

    def reset(self):
        """Reset the count of emitted instructions."""
        self.times = 0

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return a "change your mind" instruction with probability ``probability``, else None."""
        if self.max_times and self.times >= self.max_times:
            return None

        # random() is in [0, 1): "<" fires with exactly `probability` (never when it is 0)
        if random.random() < self.probability:
            self.times += 1
            instruction = "Change your mind completely, in your next utterance, suggest something completely different!"
            if self.reasons:
                instruction += f" **Reason:** {random.choice(self.reasons)}."
            return instruction
        return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class SimpleResponseOrchestrator(BaseOrchestrator):
    """Suggest candidate responses to the target agent by semantic similarity.

    The agent's previous utterance is used as the query when a ``graph`` is
    configured and the agent has already spoken. Otherwise a lookahead
    response (``agent.response_lookahead()``) is used. The query is embedded
    with a SentenceTransformer model and compared against the embeddings of
    ``responses``. The ``top_k`` most similar candidates are then offered to
    the agent as an instruction.

    Args:
        responses: either a list of candidate utterances, or a dict mapping
            an action label to its utterance.
        graph: optional ``action -> next action`` mapping. Only kept when
            ``responses`` is a dict. It is used to suggest the action that
            follows the one matched by the agent's previous utterance.
        sbert_model: SentenceTransformer model name used for the embeddings.
        top_k: number of candidates to suggest.
        persistent: forwarded to ``BaseOrchestrator`` (same as the other
            orchestrators).
        event_label: forwarded to ``BaseOrchestrator``.
    """

    def __init__(self,
                 responses: Union[List[str], Dict[str, str]],
                 graph: Dict[str, str] = None,
                 sbert_model: str = "sergioburdisso/dialog2flow-joint-bert-base",
                 top_k: int = 5,
                 persistent: bool = False,
                 event_label: str = None):
        # The base initializer must run: instruct() relies on
        # get_event_label()/get_target_agent(), and agents call is_persistent().
        super().__init__(persistent=persistent, event_label=event_label)

        self.sent_encoder = SentenceTransformer(sbert_model)
        self.responses = responses
        self.top_k = top_k

        if isinstance(responses, dict):
            self.resp_utts = np.array(list(responses.values()))
            self.resp_acts = np.array(list(responses.keys()))
            self.graph = graph
        else:
            self.resp_utts = np.array(responses)
            self.resp_acts = None
            self.graph = None

        self.resp_utt_embs = self.sent_encoder.encode(self.resp_utts)

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Build an ``Instruction`` listing the most relevant candidate responses."""
        agent = self.get_target_agent()

        # With a graph, anchor on the agent's own most recent utterance, if any.
        agent_last_turn = None
        if self.graph and dialog:
            for turn in reversed(dialog):
                if turn.speaker == agent.get_name():
                    agent_last_turn = turn.text
                    break

        response = agent_last_turn if agent_last_turn else agent.response_lookahead()

        events = [Event(agent=agent.get_name(),
                        action="request_suggestions",
                        actionLabel=self.get_event_label(),
                        text=f'Previous response: "{response}"' if agent_last_turn else f'Lookahead response: "{response}"',
                        timestamp=int(time()))]

        sims = self.sent_encoder.similarity(self.sent_encoder.encode(response), self.resp_utt_embs)[0]
        top_k_ixs = sims.argsort(descending=True)[:self.top_k]

        if self.resp_acts is None:
            instruction = ("If applicable, try to pick your next response from the following list: " +
                           "; ".join(f'({ix + 1}) {resp}' for ix, resp in enumerate(self.resp_utts[top_k_ixs])))
        else:
            next_actions = self.resp_acts[top_k_ixs].tolist()
            events.append(Event(agent=agent.get_name(),
                                action="request_suggestions",
                                actionLabel=self.get_event_label(),
                                text="Actions for the response: " + ", ".join(next_actions),
                                timestamp=int(time())))
            if agent_last_turn:
                # Map each matched action to its successor in the graph (if any).
                next_actions = [self.graph.get(action, action) for action in next_actions]
                events.append(Event(agent=agent.get_name(),
                                    action="request_suggestions",
                                    actionLabel=self.get_event_label(),
                                    text="Graph next actions: " + ", ".join(next_actions),
                                    timestamp=int(time())))

            # Keep only actions that have a response. Several similar responses
            # can map to the same action, so drop duplicates while preserving
            # their relevance order.
            next_actions = list(dict.fromkeys(action for action in next_actions if action in self.responses))
            instruction = ("If applicable, pick your next response from the following action list in order of importance: " +
                           "; ".join(f'({ix + 1}) Action: {action}. Response: "{self.responses[action]}"'
                                     for ix, action in enumerate(next_actions)))

        return Instruction(text=instruction, events=events)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class InstructionListOrchestrator(BaseOrchestrator):
    """Hand the target agent pre-scripted instructions keyed by its turn count.

    Args:
        instructions: either a list whose i-th entry is returned once the
            target agent has produced exactly i turns, or a dict mapping such
            turn counts to instructions. Counts missing from the dict produce
            no instruction.
        persistent: forwarded to ``BaseOrchestrator``.
    """

    def __init__(self,
                 instructions: Union[List[str], Dict[int, str]],
                 persistent: bool = False):
        super().__init__(persistent=persistent)
        self.instructions = instructions

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the instruction scheduled for the agent's current turn, or ``None``."""
        agent = self.get_target_agent()

        # Number of turns the target agent has already produced.
        n_agent_turns = 0
        if dialog:
            agent_name = agent.get_name()
            n_agent_turns = sum(1 for turn in dialog if turn.speaker == agent_name)

        # isinstance (not exact type equality) also accepts dict/list subclasses.
        if isinstance(self.instructions, dict):
            return self.instructions.get(n_agent_turns)
        if isinstance(self.instructions, list) and n_agent_turns < len(self.instructions):
            return self.instructions[n_agent_turns]
        return None
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
from time import time
|
|
5
|
+
from tqdm.auto import trange
|
|
6
|
+
from print_color import print
|
|
7
|
+
from typing import List, Union, Optional
|
|
8
|
+
from langchain_ollama.chat_models import ChatOllama
|
|
9
|
+
# from langchain_core.prompts import PromptTemplate
|
|
10
|
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|
11
|
+
|
|
12
|
+
from . import Dialog, Turn, Event, Instruction
|
|
13
|
+
from .orchestrators import BaseOrchestrator
|
|
14
|
+
from .util import make_serializable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Meta(type):
    """Metaclass that chains ``__init__`` calls automatically.

    For every class built with it, ``__init__`` is replaced by a wrapper. The
    wrapper first calls ``__init__`` of each direct base class with the same
    arguments, then the ``__init__`` written in the class body, if any.
    Subclasses therefore do not need to call ``super().__init__`` themselves.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        # Only the initializer defined in *this* class body. An inherited one
        # already runs through the base-class calls below, and calling it again
        # would initialize the instance twice.
        own_init = dct.get("__init__")

        def auto__call__init__(self, *a, **kw):
            for base in cls.__bases__:
                # object.__init__ raises TypeError on extra arguments when
                # __init__ is overridden, so it must not receive them.
                if base is not object:
                    base.__init__(self, *a, **kw)
            if own_init is not None:
                own_init(self, *a, **kw)

        cls.__init__child_ = own_init
        cls.__init__ = auto__call__init__
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BasePersona(metaclass=Meta):
    """Base class for personas: a free-form bag of attributes.

    Keyword arguments given at construction are stored as instance
    attributes. Those instance attributes are exactly what ``description()``
    and ``json()`` report.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def description(self) -> str:
        """Render every instance attribute as a ``Your <key>: <value>`` line."""
        lines = [f"Your {key}: {value}" for key, value in vars(self).items()]
        return "\n".join(lines)

    def __str__(self) -> str:
        return self.description()

    def json(self, string: bool = False, indent=None):
        """Return a serializable copy of the instance attributes.

        Args:
            string: when true, return the data dumped as a JSON string
                instead of a dict.
            indent: passed to ``json.dumps`` when ``string`` is true.
        """
        payload = dict(vars(self))
        make_serializable(payload)
        if string:
            return json.dumps(payload, indent=indent)
        return payload
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Persona(BasePersona):
    """Default persona with the usual character-sheet fields.

    The annotated fields below are class-level defaults, all empty strings.
    ``description()`` and ``json()`` (inherited from ``BasePersona``) read only
    instance attributes. A field never set on the instance is therefore left
    out of both.
    """

    # Free-text fields. When set on an instance, each one is rendered by
    # description() as a "Your <field>: <value>" line of the persona prompt.
    name: str = ""  # also used by PersonaAgent as its fallback agent name
    role: str = ""
    background: str = ""
    personality: str = ""
    circumstances: str = ""
    rules: str = ""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class PersonaAgent:
    """LLM-backed agent that role-plays a persona in a dialogue.

    The agent keeps its conversation in ``self.memory``, a list of LangChain
    messages whose first element is the system prompt. Orchestrators attached
    to the agent may inject extra system instructions before each response.
    """

    # Token the agent is asked to output (when ``can_finish``) to end the dialogue.
    STOP_WORD = "STOP"
    # Text that replaces STOP_WORD in the utterance returned to callers.
    STOP_WORD_TEXT = "(bye bye!)"

    def __init__(self,
                 model: Union[str, ChatOllama],
                 persona: BasePersona = None,
                 name: str = None,
                 dialogue_details: str = "",
                 response_details: str = "responses SHOULD NOT be too long and wordy, should be approximately one utterance long",
                 system_prompt: str = None,
                 can_finish: bool = False,
                 orchestrators: Union[BaseOrchestrator, List[BaseOrchestrator]] = None,
                 scenario: Union[dict, str] = None):
        """Create the agent.

        Args:
            model: an Ollama model name, used to build a ``ChatOllama`` with
                temperature 0.8 and seed 13, or an existing chat model used
                as-is.
            persona: persona rendered into the system prompt. When omitted, a
                fresh empty ``Persona`` is created per agent. It is no longer a
                single instance shared by every agent.
            name: agent name. Falls back to ``persona.name`` when not given.
            dialogue_details: optional extra prompt line about the dialogue.
            response_details: optional extra prompt line about the responses.
            system_prompt: when given, used verbatim instead of the generated
                prompt. ``persona``, the details and ``can_finish`` then do not
                affect the prompt.
            can_finish: whether the agent is told it may end the dialogue by
                outputting ``STOP_WORD``.
            orchestrators: orchestrator(s) attached via ``add_orchestrators``.
            scenario: stored and reported as the scenario by ``dialog_with``.
        """
        if persona is None:
            persona = Persona()

        if not system_prompt:
            if can_finish:
                conversation_end_instructions = f"To finish the conversation you first have to say good bye and immediately after you **MUST** output '{self.STOP_WORD}' to indicate it is the end of it."
            else:
                conversation_end_instructions = "When the user finish the conversation you should say good bye and also finish the conversation"

            system_prompt = f"""Role play as a character that is described by the persona defined in the following lines. You always stay in character.
[[ ## BEGING PERSONA ## ]]
{persona}
[[ ## END PERSONA ## ]]
---
{"Details about the overall dialogue: " + dialogue_details if dialogue_details else ""}
{"Details about your responses: " + response_details if response_details else ""}
Finally, remember:
1. You always stay on character. You are the character described above.
2. Your first utterance / turn MUST always be a short generic greeting (e.g. "Hello, how are you?", "Hi!", "hey! what's up?", etc.), and nothing else, wait for a reply before start with the actual conversation.
3. {conversation_end_instructions}."""

        if isinstance(model, str):
            # TODO: ChatHuggingFace
            self.llm = ChatOllama(model=model,
                                  temperature=0.8,
                                  seed=13)
        else:
            self.llm = model
        self.memory = [SystemMessage(system_prompt)]

        self.name = name if name else getattr(persona, "name", None)
        self.persona = persona
        self.model_name = str(self.llm)
        self.first_utterances = None
        self.finished = False
        self.scenario = scenario
        self.orchestrators = None
        self.add_orchestrators(orchestrators)

    def __call__(self, utterance: str = "", return_events: bool = False) -> str:
        """Produce the agent's next response to ``utterance``.

        Returns ``None`` once the agent has finished. Otherwise, with
        ``return_events`` it returns the list of events generated in this
        call (instructions and the final "utter" event, if any). Without it,
        it returns the response text ("" when empty).
        """
        if self.finished:
            return None

        if utterance:
            self.memory.append(HumanMessage(content=utterance))

        events = []
        if self.orchestrators:
            for orchestrator in self.orchestrators:
                instruction = orchestrator()
                if not instruction:
                    continue

                if isinstance(instruction, Instruction):
                    if return_events and instruction.events:
                        if isinstance(instruction.events, Event):
                            events.append(instruction.events)
                        else:
                            events.extend(instruction.events)
                    instruction = instruction.text

                persist = orchestrator.is_persistent()
                self.instruct(instruction, persist=persist)
                if return_events:
                    events.append(Event(agent=self.get_name(),
                                        action="instruct" + ("-persist" if persist else ""),
                                        actionLabel=orchestrator.get_event_label(),
                                        text=instruction,
                                        timestamp=int(time())))

        if len(self.memory) <= 1 and self.first_utterances:
            if isinstance(self.first_utterances, list):
                response = random.choice(self.first_utterances)
            else:
                response = self.first_utterances
            response = AIMessage(content=response)
        else:
            response = self.llm.invoke(self.memory)

        if self.orchestrators:
            # Drop one-shot (non-persistent) instructions now that they were used.
            self.memory[:] = [msg for msg in self.memory
                              if not (msg.response_metadata and "persist" in msg.response_metadata and not msg.response_metadata["persist"])]
        self.memory.append(response)

        response = response.content
        if self.STOP_WORD in response:
            response = response.replace(self.STOP_WORD, self.STOP_WORD_TEXT).strip()
            # Keep the stop token out of the agent's own memory.
            self.memory[-1].content = self.memory[-1].content.replace(self.STOP_WORD, "").strip()
            self.finished = True

        if return_events:
            if response:
                events.append(Event(agent=self.get_name(),
                                    action="utter",
                                    text=response,
                                    timestamp=int(time())))
            return events
        return response if response else ""

    def __or__(self, orchestrator: Union[BaseOrchestrator, List[BaseOrchestrator]]):
        """Attach orchestrator(s) with ``agent | orchestrator`` and return the agent."""
        self.add_orchestrators(orchestrator)
        return self

    def response_lookahead(self, utterance: str = None):
        """Return what the agent would answer now, without changing its memory."""
        if not utterance:
            return self.llm.invoke(self.memory).content
        return self.llm.invoke(self.memory + [HumanMessage(utterance)]).content

    def add_orchestrators(self, orchestrators):
        """Attach one orchestrator or an iterable of them, targeting this agent."""
        if not orchestrators:
            return

        if self.orchestrators is None:
            self.orchestrators = []

        if isinstance(orchestrators, BaseOrchestrator):
            orchestrators = [orchestrators]
        else:
            # Materialize once: a generator would otherwise be exhausted by
            # extend() before the targets are set below.
            orchestrators = list(orchestrators)

        self.orchestrators.extend(orchestrators)

        for orchestrator in orchestrators:
            orchestrator._set_target_agent(self)

    def clear_orchestrators(self):
        """Detach all orchestrators."""
        self.orchestrators = None

    def instruct(self, instruction: str, persist: bool = False):
        """Append a system instruction; non-persistent ones are dropped after the next response."""
        self.memory.append(SystemMessage(instruction, response_metadata={"persist": persist}))

    def set_first_utterances(self, utterances: Union[str, List[str]]):
        """Set a fixed first utterance, or a list to pick one from at random."""
        self.first_utterances = utterances

    def get_name(self):
        return self.name

    def get_prompt(self):
        """Return the system prompt (the first message in memory)."""
        return self.memory[0].content

    def json(self, string: bool = False, indent=None):
        """Return the agent configuration as a dict, or as a JSON string when ``string`` is true."""
        data = {}
        if self.name:
            data["name"] = self.name
        data["model_name"] = self.model_name
        if self.first_utterances:
            data["first_utterances"] = self.first_utterances
        data["persona"] = self.persona.json()
        if self.orchestrators:
            data["persona"]["orchestrators"] = [orc.json() for orc in self.orchestrators]
        return json.dumps(data, indent=indent) if string else data

    def reset(self, seed: int = None):
        """Clear the conversation (keeping the system prompt), re-seed the LLM and reset orchestrators."""
        self.memory[:] = self.memory[:1]
        self.finished = False
        self.llm.seed = seed

        if self.orchestrators:
            for orchestrator in self.orchestrators:
                orchestrator.reset()

        # hack to avoid seed bug in prompt cache (to force a new cache, related to https://github.com/ollama/ollama/issues/5321)
        prev_num_predict = self.llm.num_predict
        self.llm.num_predict = 1
        try:
            self.llm.invoke(self.memory)
        finally:
            # Always restore, otherwise a failed warm-up call would leave the
            # model generating a single token per response.
            self.llm.num_predict = prev_num_predict

    def dialog_with(self,
                    persona: "PersonaAgent",
                    max_iterations: int = 20,
                    id: int = None,
                    seed: int = None,
                    keep_bar: bool = True):
        """Run a dialogue between this agent (speaking first) and ``persona``.

        Args:
            persona: the other agent.
            max_iterations: maximum number of exchanges (one turn per agent).
            id: dialogue id stored in the result (0 is kept as a valid id).
            seed: seed for Python's ``random`` module and both LLMs. A random
                one is drawn when omitted.
            keep_bar: whether to keep the progress bar after finishing.

        Returns:
            A ``Dialog`` with the turns, events, seed and scenario. It is
            marked ``complete`` when one of the agents stopped producing
            utterances before ``max_iterations`` ran out.
        """
        seed = seed if seed is not None else random.getrandbits(32)

        random.seed(seed)
        self.reset(seed)
        persona.reset(seed)

        dialog = []
        events = []

        utter = None
        completion = False
        tqdm_iterator = trange(max_iterations, desc="Dialogue", leave=keep_bar)
        for _ in tqdm_iterator:
            utt_events = self(utter, return_events=True)

            if utt_events and utt_events[-1].action == "utter":
                utter = utt_events[-1].text
                utt_events[-1].text = utter.replace(self.STOP_WORD_TEXT, "").strip()
                if not utt_events[-1].text:
                    break
            else:
                completion = True
                break

            dialog.append(Turn(
                speaker=self.get_name() if self.get_name() else "Me",
                text=utt_events[-1].text
            ))
            events.extend(utt_events)

            utt_events = persona(utter, return_events=True)
            if utt_events and utt_events[-1].action == "utter":
                utter = utt_events[-1].text
                utt_events[-1].text = utter.replace(self.STOP_WORD_TEXT, "").strip()
                if not utt_events[-1].text:
                    break
            else:
                completion = True
                break

            dialog.append(Turn(
                speaker=persona.get_name() if persona.get_name() else "Other",
                text=utt_events[-1].text
            ))
            events.extend(utt_events)

        if not keep_bar:
            # tqdm.auto yields a notebook widget (which has `.container`) under
            # Jupyter, but a console bar elsewhere that has no such attribute.
            container = getattr(tqdm_iterator, "container", None)
            if container is not None:
                container.close()
            else:
                tqdm_iterator.close()

        if self.scenario:
            scenario = self.scenario
        else:
            scenario = {
                "agents": [
                    self.json(),
                    persona.json()
                ]
            }

        return Dialog(
            dialogId=id,
            complete=completion,  # incomplete if ran out of iterations (reached max_iteration number)
            model=self.model_name,
            seed=seed,
            scenario=scenario,
            turns=dialog,
            events=events
        )

    talk_with = dialog_with
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sdialog
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Synthetic Dialogue Generation and Analysis
|
|
5
|
+
Author-email: Sergio Burdisso <sergio.burdisso@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/idiap/sdialog
|
|
8
|
+
Project-URL: Issues, https://github.com/idiap/sdialog/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.9
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: print-color
|
|
15
|
+
Requires-Dist: langchain
|
|
16
|
+
Requires-Dist: langchain-ollama
|
|
17
|
+
Requires-Dist: tqdm
|
|
18
|
+
Requires-Dist: plotly
|
|
19
|
+
Requires-Dist: sentence-transformers
|
|
20
|
+
Requires-Dist: pandas
|
|
21
|
+
Requires-Dist: tenacity
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: flake8
|
|
24
|
+
Requires-Dist: pytest
|
|
25
|
+
Requires-Dist: ollama
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# SDialog
|
|
29
|
+
|
|
30
|
+
Synthetic Dialogue Generation and Analysis
|
|
31
|
+
|
|
32
|
+
_(Coming soon)_
|
|
33
|
+
|
|
34
|
+
This package requires `Ollama` to be running on your system.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
requirements.txt
|
|
5
|
+
src/sdialog/__init__.py
|
|
6
|
+
src/sdialog/datasets.py
|
|
7
|
+
src/sdialog/generators.py
|
|
8
|
+
src/sdialog/orchestrators.py
|
|
9
|
+
src/sdialog/personas.py
|
|
10
|
+
src/sdialog/util.py
|
|
11
|
+
src/sdialog.egg-info/PKG-INFO
|
|
12
|
+
src/sdialog.egg-info/SOURCES.txt
|
|
13
|
+
src/sdialog.egg-info/dependency_links.txt
|
|
14
|
+
src/sdialog.egg-info/requires.txt
|
|
15
|
+
src/sdialog.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
sdialog
|