sdialog 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sdialog-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Idiap Research Institute
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
sdialog-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,34 @@
1
+ Metadata-Version: 2.4
2
+ Name: sdialog
3
+ Version: 0.0.1
4
+ Summary: Synthetic Dialogue Generation and Analysis
5
+ Author-email: Sergio Burdisso <sergio.burdisso@gmail.com>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/idiap/sdialog
8
+ Project-URL: Issues, https://github.com/idiap/sdialog/issues
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.9
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: print-color
15
+ Requires-Dist: langchain
16
+ Requires-Dist: langchain-ollama
17
+ Requires-Dist: tqdm
18
+ Requires-Dist: plotly
19
+ Requires-Dist: sentence-transformers
20
+ Requires-Dist: pandas
21
+ Requires-Dist: tenacity
22
+ Requires-Dist: numpy
23
+ Requires-Dist: flake8
24
+ Requires-Dist: pytest
25
+ Requires-Dist: ollama
26
+ Dynamic: license-file
27
+
28
+ # SDialog
29
+
30
+ Synthetic Dialogue Generation and Analysis
31
+
32
+ _(Coming soon)_
33
+
34
+ This package requires `Ollama` running on your system.
@@ -0,0 +1,7 @@
1
+ # SDialog
2
+
3
+ Synthetic Dialogue Generation and Analysis
4
+
5
+ _(Coming soon)_
6
+
7
+ This package requires `Ollama` running on your system.
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["setuptools >= 77.0.3"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sdialog"
7
+ version = "0.0.1"
8
+ authors = [
9
+ { name="Sergio Burdisso", email="sergio.burdisso@gmail.com" },
10
+ ]
11
+ description = "Synthetic Dialogue Generation and Analysis"
12
+ readme = "README.md"
13
+ requires-python = ">=3.9"
14
+ dynamic = ["dependencies"]
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ license = "MIT"
20
+ license-files = ["LICEN[CS]E*"]
21
+
22
+ [tool.setuptools.dynamic]
23
+ dependencies = {file = ["requirements.txt"]}
24
+
25
+ [project.urls]
26
+ Homepage = "https://github.com/idiap/sdialog"
27
+ Issues = "https://github.com/idiap/sdialog/issues"
@@ -0,0 +1,12 @@
1
+ print-color
2
+ langchain
3
+ langchain-ollama
4
+ tqdm
5
+ plotly
6
+ sentence-transformers
7
+ pandas
8
+ tenacity
9
+ numpy
10
+ flake8
11
+ pytest
12
+ ollama
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,131 @@
1
+ import json
2
+
3
+ from pydantic import BaseModel
4
+ from typing import List, Union, Optional
5
+ from print_color import print
6
+
7
+ from .util import make_serializable
8
+
9
+
10
class Turn(BaseModel):
    """A single utterance of a dialogue."""

    # Label of whoever produced the utterance; None is an accepted value.
    speaker: Optional[str]
    # The utterance itself.
    text: str
13
+
14
+
15
class Event(BaseModel):
    """One entry of a dialogue's event log (utterances, instructions, ...)."""

    agent: Optional[str] = None  # "user", "system"
    action: str  # "utter", "instruct"
    actionLabel: Optional[str] = None  # action label (e.g. type of instruct)
    text: str  # the content of the event
    timestamp: int  # timestamp of the event
21
+
22
+
23
class Dialog(BaseModel):
    """A dialogue: an ordered list of turns plus optional generation metadata."""

    formatVersion: Optional[str] = "0.0.5"  # version of the format
    model: Optional[str] = None  # the model used to generate the dialogue
    seed: Optional[int] = None  # the seed used to generate the dialogue
    dialogId: Optional[int] = None
    complete: Optional[bool] = None
    scenario: Optional[Union[dict, str]] = None  # the scenario used to generate the dialogue
    turns: List[Turn]  # the list of turns of the conversation
    events: Optional[List[Event]] = None

    def __len__(self):
        """Return the number of turns in the dialogue."""
        return len(self.turns)

    def description(self, turn_template: str = "{speaker}: {text}"):
        """Render the dialogue as text, one turn per line.

        :param turn_template: ``str.format`` template receiving ``speaker``
            and ``text`` for every turn.
        :return: the rendered turns joined with newlines. Newlines inside a
            turn's text are replaced by spaces so each turn stays on one line.
        """
        return "\n".join(turn_template.format(speaker=turn.speaker, text=turn.text.replace("\n", " "))
                         for turn in self.turns)

    def json(self, string: bool = False, indent: Optional[int] = None):
        """Return the dialogue as a plain ``dict``, or as a JSON string.

        :param string: when true, return a JSON string instead of a dict.
        :param indent: indentation forwarded to ``json.dumps``.
        """
        data = self.model_dump()
        # The return value is ignored, so this helper is expected to convert
        # the values of ``data`` in place.
        make_serializable(data)
        return json.dumps(data, indent=indent) if string else data

    def print(self, *a, **kw):
        """Pretty-print the dialogue; arguments are forwarded to ``print_dialog``."""
        print_dialog(self, *a, **kw)

    def to_file(self, path: str, type: str = "auto"):
        """Write the dialogue to ``path``.

        :param path: destination file.
        :param type: ``"json"``, ``"txt"`` or ``"auto"``; ``"auto"`` picks
            JSON when ``path`` ends with ``.json`` and text otherwise.
        """
        if type == "auto":
            type = "json" if path.endswith(".json") else "txt"

        # Explicit encoding so that files round-trip identically on every
        # platform instead of depending on the locale's default codec.
        with open(path, "w", encoding="utf-8") as writer:
            if type == "json":
                writer.write(self.json(string=True))
            else:
                writer.write(self.description())

    @staticmethod
    def from_file(path: str, type: str = "auto"):
        """Load a dialogue from ``path``.

        :param path: source file.
        :param type: ``"json"``, ``"txt"`` or ``"auto"``; ``"auto"`` picks
            JSON when ``path`` ends with ``.json`` and text otherwise.
        :return: the loaded ``Dialog``. For text files every non-blank line
            becomes one turn, split at the first ``":"`` into speaker and
            text; a line without ``":"`` becomes a turn with no speaker.
        """
        if type == "auto":
            type = "json" if path.endswith(".json") else "txt"

        with open(path, encoding="utf-8") as reader:
            if type == "json":
                return Dialog.model_validate(json.load(reader))
            lines = reader.read().split("\n")

        turns = []
        for line in lines:
            if not line.strip():
                continue  # blank or whitespace-only line
            speaker, separator, text = line.partition(":")
            if separator:
                turns.append(Turn(speaker=speaker.strip(), text=text.strip()))
            else:
                # No speaker prefix: keep the text rather than failing.
                turns.append(Turn(speaker=None, text=line.strip()))
        return Dialog(turns=turns)

    @staticmethod
    def from_dict(data: dict):
        """Build a ``Dialog`` from a plain dict (alias of ``model_validate``)."""
        return Dialog.model_validate(data)

    __str__ = description
76
+
77
+
78
class Instruction(BaseModel):
    """An instruction for an agent, optionally carrying extra events."""

    # Annotated as Optional because the default is None.
    text: Optional[str] = None
    events: Optional[Union[Event, List[Event]]] = None  # extra events
81
+
82
+
83
def print_dialog(dialog: Union[Dialog, dict], scenario: bool = False, orchestration: bool = False):
    """Print a dialogue to the terminal with one color per speaker.

    :param dialog: a ``Dialog`` or a dict that validates as one.
    :param scenario: when true, also print the dialogue's scenario (if any).
    :param orchestration: when true, print the event log instead of the
        turns, so non-utterance events show up inline.
    """
    if isinstance(dialog, dict):
        dialog = Dialog.model_validate(dialog)

    speaker_tag_colors = ["red", "blue", "yellow", "cyan", "green", "magenta", "purple"]
    speaker_utt_colors = ["grey", "white"]
    # speaker_utt_colors = ["black", "grey"]

    # 0 is a legitimate id / seed, so test against None rather than truthiness.
    if dialog.dialogId is not None:
        print(dialog.dialogId, tag="dialog_id", tag_color="purple", color="magenta", format="bold")
    if dialog.complete:
        print(dialog.complete, tag="complete", tag_color="purple", color="magenta", format="bold")
    if dialog.model:
        print(dialog.model, tag="model", tag_color="purple", color="magenta", format="bold")
    if dialog.seed is not None:
        print(dialog.seed, tag="seed", tag_color="purple", color="magenta", format="bold")
    if scenario and dialog.scenario:
        print("", tag="scenario", tag_color="purple", color="magenta", format="bold")
        if isinstance(dialog.scenario, str):
            print(dialog.scenario, color="magenta")
        else:
            print(json.dumps(dialog.scenario, indent=2), color="magenta")

    print("--- Dialogue Begins ---", color="magenta", format="bold")
    # Speakers may be None; sort those first instead of letting sorted()
    # fail when comparing None with str.
    speakers = sorted({turn.speaker for turn in dialog.turns},
                      key=lambda name: (name is not None, name or ""))
    speaker_index = {name: position for position, name in enumerate(speakers)}

    if orchestration:
        # Work on a copy so the caller's dialog keeps its original turns.
        dialog = dialog.model_copy()
        event_turns = []
        for e in dialog.events or []:  # events is optional
            if e.action == "utter":
                event_turns.append(Turn(speaker=e.agent, text=e.text))
            elif e.action == "pick_suggestion":
                event_turns.append(Turn(speaker=e.agent, text=f"[pick_suggestion] {e.text}"))
            else:
                event_turns.append(Turn(speaker=e.action, text=f"({e.agent}) {e.text}"))
        dialog.turns = event_turns

    for turn in dialog.turns:
        speaker = turn.speaker

        if speaker not in speaker_index:
            # Only reachable in orchestration mode: the "speaker" is an
            # action name that never appeared in the real turns.
            tag_color = "yellow"
            color = "purple"
        else:
            position = speaker_index[speaker]
            tag_color = speaker_tag_colors[position % len(speaker_tag_colors)]
            color = speaker_utt_colors[position % len(speaker_utt_colors)]

        print(turn.text,
              tag=speaker,
              tag_color=tag_color,
              color=color)
    print("--- Dialogue Ends ---", color="magenta", format="bold")
@@ -0,0 +1,262 @@
1
+ import os
2
+ import re
3
+ import json
4
+
5
+ from tqdm.auto import tqdm
6
+
7
+ from . import Dialog, Turn, Event
8
+ from .personas import Persona, PersonaAgent
9
+ from .orchestrators import InstructionListOrchestrator, SimpleResponseOrchestrator
10
+
11
+ class STAR:
12
+ _path = None
13
+ _speakers = ["User", "Wizard"]
14
+
15
+ @staticmethod
16
+ def set_path(path):
17
+ STAR._path = path
18
+
19
+ @staticmethod
20
+ def read_graph(task_name, as_dot: bool = True):
21
+ with open(os.path.join(STAR._path, f"tasks/{task_name}/{task_name}.json")) as reader:
22
+ if not as_dot:
23
+ return json.load(reader)["graph"]
24
+ dot_edges = ";\n".join(f" {a} -> {b}" for a,b in json.load(reader)["graph"].items())
25
+
26
+ return "digraph %s {\n%s\n}" % (task_name, dot_edges)
27
+
28
+ @staticmethod
29
+ def read_graph_responses(task_name, as_dict: bool = False):
30
+ with open(os.path.join(STAR._path, f"tasks/{task_name}/responses.json")) as reader:
31
+ responses = json.load(reader)
32
+ responses = {key:re.sub(r"{(.+?)(?::\w+?)?}", lambda m:m.group(1).upper(), value)
33
+ for key, value in responses.items()
34
+ if key != "out_of_scope"}
35
+ return responses if as_dict else json.dumps(responses, indent=2)
36
+
37
+ @staticmethod
38
+ def get_dialog(id):
39
+ with open(os.path.join(STAR._path, f"dialogues/{id}.json")) as reader:
40
+ dialog = json.load(reader)
41
+
42
+ for e in dialog["Events"]:
43
+ if e["Agent"] == "Wizard":
44
+ e["Agent"] = "System"
45
+
46
+ return Dialog(
47
+ dialogId=id,
48
+ scenario=dialog["Scenario"],
49
+ turns=[Turn(speaker=e["Agent"], text=e["Text"])
50
+ for e in dialog["Events"]
51
+ if e["Action"] in ["utter", "pick_suggestion"]],
52
+ events=[Event(agent=e["Agent"],
53
+ action=e["Action"],
54
+ actionLabel=e["ActionLabel"] if "ActionLabel" in e else None,
55
+ text=e["Text"],
56
+ timestamp=e["UnixTime"])
57
+ for e in dialog["Events"]
58
+ if "Text" in e]
59
+ )
60
+
61
+ @staticmethod
62
+ def get_dialogs(domain: str = None, task_name: str = None, happy: bool = None, multitask: bool = None):
63
+ dialogs = []
64
+ for fname in tqdm(os.listdir(os.path.join(STAR._path, f"dialogues/")), desc="Reading dialogs", leave=False):
65
+ if not fname.endswith(".json"):
66
+ continue
67
+ dialog_id = int(os.path.splitext(fname)[0])
68
+ scenario = STAR.get_dialog_scenario(dialog_id)
69
+
70
+ if (domain is None or domain in scenario["Domains"]) and \
71
+ (happy is None or scenario["Happy"] == happy) and \
72
+ (multitask is None or scenario["MultiTask"] == multitask) and \
73
+ (task_name is None or any(capability["Task"] == task_name for capability in scenario["WizardCapabilities"])):
74
+ dialogs.append(STAR.get_dialog(dialog_id))
75
+ return dialogs
76
+
77
+ @staticmethod
78
+ def get_dialog_scenario(id):
79
+ with open(os.path.join(STAR._path, f"dialogues/{id}.json")) as reader:
80
+ return json.load(reader)["Scenario"]
81
+
82
+ @staticmethod
83
+ def get_dialog_first_turn(id, speaker: str = None):
84
+ with open(os.path.join(STAR._path, f"dialogues/{id}.json")) as reader:
85
+ for event in json.load(reader)["Events"]:
86
+ turn_speaker = event["Agent"]
87
+ if speaker == None and turn_speaker in STAR._speakers:
88
+ return Turn(speaker=turn_speaker, text=event["Text"])
89
+ elif turn_speaker == speaker:
90
+ return Turn(speaker=turn_speaker, text=event["Text"])
91
+
92
+ @staticmethod
93
+ def get_dialog_task_names(id):
94
+ scenario = STAR.get_dialog_scenario(id)
95
+ return [task["Task"] for task in scenario["WizardCapabilities"]]
96
+
97
+ @staticmethod
98
+ def get_dialog_responses(id):
99
+ tasks = STAR.get_dialog_task_names(id)
100
+ return [STAR.read_graph_responses(task, as_dict=True) for task in tasks]
101
+
102
+ @staticmethod
103
+ def get_dialog_graphs(id):
104
+ tasks = STAR.get_dialog_task_names(id)
105
+ return [STAR.read_graph(task, as_dot=False) for task in tasks]
106
+
107
+ @staticmethod
108
+ def get_dialog_events(id):
109
+ with open(os.path.join(STAR._path, f"dialogues/{id}.json")) as reader:
110
+ return json.load(reader)["Events"]
111
+
112
+ @staticmethod
113
+ def get_dialog_events(id):
114
+ with open(os.path.join(STAR._path, f"dialogues/{id}.json")) as reader:
115
+ return json.load(reader)["Events"]
116
+
117
+ @staticmethod
118
+ def get_dialog_user_instructions(id):
119
+ def get_user_n_turns_before(turn_ix, events):
120
+ return len([e for e in events[:turn_ix]
121
+ if e["Agent"] == "User" and e["Action"] == "utter"])
122
+ events = STAR.get_dialog_events(id)
123
+ return {get_user_n_turns_before(ix, events): e["Text"]
124
+ for ix, e in enumerate(events)
125
+ if e["Action"] == "instruct" and e["Agent"] == "UserGuide"}
126
+
127
+ @staticmethod
128
+ def get_dialog_graphs_and_responses(id):
129
+ return STAR.get_dialog_graphs(id), STAR.get_dialog_responses(id)
130
+
131
+ @staticmethod
132
+ def get_scenario_description(scenario):
133
+ # Let's generate the graph description for each task:
134
+ flowcharts = ""
135
+ for task in scenario["WizardCapabilities"]:
136
+ task_name = task["Task"]
137
+ flowcharts += f"""
138
+ The graph for the task '{task_name}' with domain '{task['Domain']}' is:
139
+ ```dot
140
+ {STAR.read_graph(task_name)}
141
+ ```
142
+ and one example responses for each node is provided in the following json:
143
+ ```json
144
+ {STAR.read_graph_responses(task_name)}
145
+ ```
146
+
147
+ ---
148
+ """
149
+ # Finally, let's return the scenario object and natural language description for it.
150
+ return f"""The conversation is between a User and a AI assistant in the following domains: {', '.join(scenario['Domains'])}.
151
+
152
+ The User instructions are: {scenario['UserTask']}
153
+ The AI assistant instructions are: {scenario['WizardTask']}
154
+
155
+ In addition, the AI assistant is instructed to follow specific flowcharts to address the tasks. Flowcharts are defined as graph described using DOT.
156
+ The actual DOT for the current tasks are:
157
+ {flowcharts}
158
+
159
+ Finally, the following should be considered regarding the conversation:
160
+ 1. {"The conversation follows the 'happy path', meaning the conversations goes according to what it is described in the flowcharts"
161
+ if scenario['Happy'] else
162
+ "The conversation does NOT follow a 'happy path', meaning something happend to the user to change its mind or something happend "
163
+ "in the environment for the conversation to not flow as expected, as described in the flowchart"}.
164
+ 2. {"The user is calling to perform multiple tasks, involving all the tasks defined as flowcharts above (" + ', '.join(task['Task'] for task in scenario['WizardCapabilities']) + ")"
165
+ if scenario['MultiTask'] else
166
+ "The user is calling to perform only the defined task (" + scenario['WizardCapabilities'][0]['Task'] + "), nothing else"}.
167
+ """
168
+
169
+ @staticmethod
170
+ def get_dialog_scenario_description(id):
171
+ scenario = STAR.get_dialog_scenario(id)
172
+ return scenario, STAR.get_scenario_description(scenario)
173
+
174
+ @staticmethod
175
+ def get_user_persona_for_scenario(scenario):
176
+ dialogue_details = f"""
177
+ The following should be considered regarding the conversation:
178
+ 1. {"The conversation follows a 'happy path', meaning the conversations goes smoothly without any unexpected behavior"
179
+ if scenario['Happy'] else
180
+ "The conversation does NOT follow a 'happy path', meaning you have to simulate something happend in the middle of the conversation, "
181
+ "perhaps you changed your mind at some point or something external happend in the environment for the conversation to not flow as expected"}.
182
+ 2. {"The conversation involves multiple tasks, that is, you want the assistant to perform multiple tasks (" + ', '.join(task['Task'] for task in scenario['WizardCapabilities']) + "), not just one."
183
+ if scenario['MultiTask'] else
184
+ "The conversation involves only one task you were instructed to (" + scenario['WizardCapabilities'][0]['Task'] + "), nothing else"}"""
185
+
186
+ return Persona(
187
+ role=f"user calling a AI assistant that can perform multiple tasks in the following domains: {', '.join(scenario['Domains'])}.\n" + dialogue_details,
188
+ circumstances=scenario["UserTask"],
189
+ )
190
+
191
+ @staticmethod
192
+ def get_flowchart_description_for_scenario(scenario):
193
+ flowcharts = ""
194
+ for task in scenario["WizardCapabilities"]:
195
+ task_name = task["Task"]
196
+ flowcharts += f"""
197
+ ## {task_name} ({task['Domain']})
198
+
199
+ The flowchart described as an action transition graph for the task '{task_name}' with domain '{task['Domain']}' is:
200
+ ```dot
201
+ {STAR.read_graph(task_name)}
202
+ ```
203
+ Response example for each action is provided in the following json:
204
+ ```json
205
+ {STAR.read_graph_responses(task_name)}
206
+ ```
207
+ where UPPERCASE words above are just example placeholders. You MUST fill in those with any coherent values in the actual conversation.
208
+ """
209
+ return flowcharts
210
+
211
+ @staticmethod
212
+ def get_system_persona_for_scenario(scenario):
213
+
214
+
215
+ dialogue_details = f"""In the conversation, the AI assistant is instructed to follow specific action flowcharts to address the tasks. Flowcharts are defined as graph described using DOT.
216
+ The actual DOT for the current tasks are:
217
+ {STAR.get_flowchart_description_for_scenario(scenario)}
218
+ """
219
+ return Persona(
220
+ role="AI assistant.\n" + dialogue_details,
221
+ circumstances=scenario['WizardTask'],
222
+ )
223
+
224
+ @staticmethod
225
+ def get_agents_for_scenario(scenario, model_name):
226
+ user = PersonaAgent(model_name,
227
+ STAR.get_user_persona_for_scenario(scenario),
228
+ name="User",
229
+ can_finish=True)
230
+
231
+ system = PersonaAgent(model_name,
232
+ STAR.get_system_persona_for_scenario(scenario),
233
+ name="System")
234
+
235
+ return system, user
236
+
237
+ @staticmethod
238
+ def get_agents_from_dialogue(id, model_name:str, set_first_utterance: bool = False):
239
+ scenario = STAR.get_dialog_scenario(id)
240
+ system, user = STAR.get_agents_for_scenario(scenario, model_name)
241
+
242
+ if set_first_utterance:
243
+ first_turn = STAR.get_dialog_first_turn(id)
244
+ if first_turn.speaker == "Wizard":
245
+ system.set_first_utterances(first_turn.text)
246
+ else:
247
+ system.set_first_utterances("Hello, how can I help?")
248
+
249
+ return system, user
250
+
251
+ @staticmethod
252
+ def get_agents_from_dialogue_with_orchestration(id, model_name:str, set_first_utterance: bool = False):
253
+ system, user = STAR.get_agents_from_dialogue(id, model_name, set_first_utterance)
254
+
255
+ graphs, responses = STAR.get_dialog_graphs_and_responses(id)
256
+ response_action_orchestrator = SimpleResponseOrchestrator(responses[0], graph=graphs[0])
257
+ instr_list_orchestrator = InstructionListOrchestrator(
258
+ STAR.get_dialog_user_instructions(id),
259
+ persistent=True
260
+ )
261
+
262
+ return system | response_action_orchestrator, user | instr_list_orchestrator
@@ -0,0 +1,82 @@
1
+ import json
2
+ import random
3
+
4
+ from langchain_ollama.chat_models import ChatOllama
5
+ from langchain_core.messages import HumanMessage, SystemMessage
6
+
7
+ from pydantic import BaseModel
8
+
9
+ from print_color import print
10
+ from typing import Union, List
11
+ from langchain_ollama.chat_models import ChatOllama
12
+ from langchain_core.messages import HumanMessage, SystemMessage
13
+
14
+ from . import Dialog, Turn
15
+
16
+
17
class LLMDialogOutput(BaseModel):
    """Default structured-output schema requested from the LLM."""

    # The generated conversation, one Turn per utterance.
    dialog: List[Turn]
19
+
20
+
21
+ # TODO: create a BaseDialogGenerator, and also PersonaDialogGenerator that takes personas objects as in multi-agent.
22
class DialogGenerator:
    """Generate a whole dialogue with a single LLM call."""

    def __init__(self, model: Union[ChatOllama, str], dialogue_details: str, output_format: Union[dict, BaseModel] = LLMDialogOutput, scenario: dict = None):
        """Create the generator.

        :param model: an Ollama model name, or an already built ``ChatOllama``.
        :param dialogue_details: the user prompt describing the dialogue.
        :param output_format: a pydantic model class or a JSON-schema dict
            used as structured-output format; falsy to get raw text.
        :param scenario: optional value for the "scenario" field of the
            output; if not provided, ``dialogue_details`` is used.
        """
        if not output_format or isinstance(output_format, dict):
            output_format_schema = output_format
            self.output_format = None
        else:
            output_format_schema = output_format.model_json_schema()
            self.output_format = output_format

        if isinstance(model, str):
            self.llm = ChatOllama(model=model,
                                  format=output_format_schema,
                                  temperature=0.8,
                                  seed=13)
            self.model_name = model
        else:
            self.llm = model
            if output_format:
                # Pass the schema, as in the branch above, not the model class.
                self.llm.format = output_format_schema
            # Dialog.model must be a string, so keep the model's name rather
            # than the LLM object itself.
            self.model_name = getattr(model, "model", None)

        self.set(dialogue_details, scenario)

    def generate(self, seed: int = None, id: int = None):
        """Generate one dialogue.

        :param seed: seed for the LLM; a random 32-bit value when ``None``.
        :param id: value stored as ``dialogId`` of the returned ``Dialog``.
        :return: the raw LLM text when no pydantic output format is set, a
            ``Dialog`` for the default format, or the validated custom model.
        """
        self.llm.seed = seed if seed is not None else random.getrandbits(32)

        # hack to avoid seed bug in prompt cache (to force a new cache, related to https://github.com/ollama/ollama/issues/5321)
        saved_num_predict = self.llm.num_predict
        self.llm.num_predict = 1
        try:
            self.llm.invoke(self.messages)
        finally:
            # Restore even if the warm-up call fails, otherwise every later
            # call would be limited to a single token.
            self.llm.num_predict = saved_num_predict

        dialogue = self.llm.invoke(self.messages).content

        if not self.output_format:
            return dialogue

        llm_output = self.output_format.model_validate(json.loads(dialogue))
        if self.output_format is not LLMDialogOutput:
            return llm_output

        return Dialog(dialogId=id,  # id 0 is kept as is
                      model=self.model_name,
                      seed=self.llm.seed,
                      scenario=self.scenario if self.scenario else self.dialogue_details,
                      turns=llm_output.dialog)

    def set(self, dialogue_details: str, scenario: dict = None):
        """Replace the dialogue description (and scenario) used by ``generate``."""
        self.scenario = scenario
        self.dialogue_details = dialogue_details
        self.messages = [
            SystemMessage(
                # The two literals are concatenated, so the first one must
                # end with a space.
                "You are a knowledgeable and useful AI assistant that can write natural conversations by role playing different speakers. "
                "The output should be a full dialogue, from begining (greetings) to end (bye bye messages)."
            ),
            HumanMessage(content=dialogue_details)
        ]

    __call__ = generate
@@ -0,0 +1,224 @@
1
+ import json
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+
6
+ from time import time
7
+ from abc import ABC, abstractmethod
8
+ from typing import List, Union, Dict, Optional
9
+ from sentence_transformers import SentenceTransformer
10
+ from langchain_core.messages import SystemMessage, AIMessage
11
+
12
+ from . import Turn, Event, Instruction
13
+ from .util import make_serializable
14
+ # from .personas import PersonaAgent
15
+
16
+
17
class BaseOrchestrator(ABC):
    """Base class of orchestrators: objects that produce instructions for a target agent.

    Subclasses implement ``instruct``; calling the orchestrator rebuilds the
    current dialogue from the target agent's memory and forwards it.
    """

    _target = None
    _event_label = None
    _persistent = False

    def __init__(self, target_agent=None, persistent: bool = None, event_label: str = None):
        """
        :param target_agent: the agent this orchestrator instructs.
        :param persistent: persistence flag; ``None`` keeps the class default.
        :param event_label: label returned by ``get_event_label``; the class
            name is used when not given.
        """
        self._target = target_agent
        # Only override when a value is given, so that subclasses declaring
        # a different class-level default (e.g. ``_persistent = True``) keep it.
        if persistent is not None:
            self._persistent = persistent
        self._event_label = event_label

    def __call__(self):
        """Run ``instruct`` on the target agent's current dialogue.

        The ``utterance`` argument is the text of the last turn when it was
        not produced by the target agent, and ``""`` otherwise.
        """
        dialog = self.__get_current_dialog()
        last_is_from_other = bool(dialog) and dialog[-1].speaker != self._target.get_name()
        return self.instruct(dialog, dialog[-1].text if last_is_from_other else "")

    def __str__(self) -> str:
        data = self.json()
        attrs = " ".join(f"{key}={value}" for key, value in data["args"].items())
        return f"{data['name']}({attrs})"

    def __get_current_dialog(self) -> List[Turn]:
        """Rebuild the dialogue from the target's memory, skipping system messages.

        AI messages are attributed to the target agent; every other message
        gets no speaker.
        """
        return [Turn(speaker=self._target.get_name() if isinstance(message, AIMessage) else None,
                     text=message.content)
                for message in self._target.memory
                if not isinstance(message, SystemMessage)]

    def _set_target_agent(self, agent):  # target: PersonaAgent
        self._target = agent

    def json(self, string: bool = False, indent: int = None):
        """Describe the orchestrator as ``{"name": ..., "args": {...}}``.

        ``args`` holds every ``__init__`` parameter that is also stored,
        under the same name, as a non-None instance attribute.

        :param string: when true, return a JSON string instead of a dict.
        :param indent: indentation forwarded to ``json.dumps``.
        """
        sig = inspect.signature(self.__init__)
        data = {"name": type(self).__name__,
                "args": {key: self.__dict__[key] for key in sig.parameters
                         if key in self.__dict__ and self.__dict__[key] is not None}}
        make_serializable(data["args"])
        return json.dumps(data, indent=indent) if string else data

    def get_event_label(self) -> str:
        return self._event_label if self._event_label else type(self).__name__

    def get_target_agent(self):
        return self._target

    def is_persistent(self):
        return self._persistent

    def set_persistent(self, value: bool):
        self._persistent = value

    def agent_response_lookahead(self):
        return self._target.response_lookahead()

    @abstractmethod
    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the instruction for the target agent, or ``None`` for none."""
        pass

    def reset(self):
        """Reset any internal state; a no-op by default."""
        pass
72
+
73
+
74
class BasePersistentOrchestrator(BaseOrchestrator):
    """Base class of orchestrators that are persistent unless told otherwise."""

    _persistent = True

    def __init__(self, target_agent=None, persistent: bool = None, event_label: str = None):
        """Same parameters as ``BaseOrchestrator``; ``persistent=None`` means ``True``.

        Resolving the default here guarantees the instance ends up
        persistent when the caller does not pass the flag.
        """
        super().__init__(target_agent=target_agent,
                         persistent=True if persistent is None else persistent,
                         event_label=event_label)

    @abstractmethod
    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the instruction for the target agent, or ``None`` for none."""
        pass

    def reset(self):
        """Reset any internal state; a no-op by default."""
        pass
83
+
84
+
85
class SimpleReflexOrchestrator(BaseOrchestrator):
    """Emit a fixed instruction whenever a predicate holds on the last utterance."""

    def __init__(self, condition: callable, instruction: str, persistent: bool = False, event_label: str = None):
        """
        :param condition: callable receiving the utterance; a truthy result
            triggers the instruction.
        :param instruction: the instruction to return when triggered.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.instruction = instruction
        self.condition = condition

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the instruction if ``condition(utterance)`` is truthy, else ``None``."""
        triggered = self.condition(utterance)
        return self.instruction if triggered else None
94
+
95
+
96
class LengthOrchestrator(BaseOrchestrator):
    """Steer the dialogue length: keep it going below ``min``, end it near ``max``."""

    def __init__(self, min: int = None, max: int = None, persistent: bool = False, event_label: str = None):
        """
        :param min: below this many turns the agent is told not to finish.
        :param max: from ``max - 1`` turns on the agent is told to finish.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.min = min
        self.max = max

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return a keep-going or finish instruction, or ``None`` when neither applies."""
        n_turns = len(dialog)

        # Too short so far (but already started): do not let it end.
        if self.min is not None and 1 < n_turns < self.min:
            return "Make sure you DO NOT finish the conversation, keep it going!"

        # One turn is reserved for the answer, hence ``max - 1``.
        if self.max and n_turns >= self.max - 1:
            return "Now FINISH the conversation AS SOON AS possible, if possible, RIGHT NOW!"

        return None
107
+
108
+
109
class ChangeMindOrchestrator(BaseOrchestrator):
    """Randomly tell the agent to change its mind, a limited number of times."""

    def __init__(self, probability: float = 0.3,
                 reasons: Union[str, List[str]] = None,
                 max_times: int = 1,
                 persistent: bool = False,
                 event_label: str = None):
        """
        :param probability: chance of firing on each call.
        :param reasons: one reason or a list of reasons; a random one is
            appended to the instruction.
        :param max_times: maximum number of times to fire; falsy for no limit.
        """
        super().__init__(persistent=persistent, event_label=event_label)
        self.probability = probability
        # A single reason is wrapped so that ``reasons`` is always list-like.
        if type(reasons) is str:
            reasons = [reasons]
        self.reasons = reasons
        self.max_times = max_times
        self.times = 0

    def reset(self):
        """Forget how many times the instruction has fired."""
        self.times = 0

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the change-your-mind instruction, or ``None`` when it does not fire."""
        exhausted = self.max_times and self.times >= self.max_times
        if exhausted:
            return None

        if random.random() > self.probability:
            return None

        self.times += 1
        parts = ["Change your mind completely, in your next utterance, suggest something completely different!"]
        if self.reasons:
            parts.append(f" **Reason:** {random.choice(self.reasons)}.")
        return "".join(parts)
134
+
135
+
136
class SimpleResponseOrchestrator(BaseOrchestrator):
    """Suggest next responses by embedding similarity to a list of known responses.

    When ``responses`` is an action -> response dict, suggestions are
    expressed as actions and, if a ``graph`` is given, advanced through it.
    """

    def __init__(self,
                 responses: List[Union[str, Dict[str, str]]],
                 graph: Dict[str, str] = None,
                 # sbert_model: str = "sentence-transformers/LaBSE",
                 sbert_model: str = "sergioburdisso/dialog2flow-joint-bert-base",
                 top_k: int = 5):
        """
        :param responses: list of responses, or dict mapping action to response.
        :param graph: action -> next action mapping; only used with a dict
            of responses.
        :param sbert_model: name of the SentenceTransformer model to load.
        :param top_k: number of most similar responses to suggest.
        """
        super().__init__()

        self.sent_encoder = SentenceTransformer(sbert_model)
        self.responses = responses
        self.top_k = top_k

        if isinstance(responses, dict):
            self.resp_utts = np.array(list(responses.values()))
            self.resp_acts = np.array(list(responses.keys()))
            self.graph = graph
        else:
            self.resp_utts = np.array(responses)
            self.resp_acts = None
            self.graph = None

        self.resp_utt_embs = self.sent_encoder.encode(self.resp_utts)

    def instruct(self, dialog: List[Turn], utterance: str) -> Instruction:
        """Build an ``Instruction`` listing the suggested next responses.

        The query is the agent's own last turn when a graph is available,
        otherwise the agent's lookahead response. The returned instruction
        also carries "request_suggestions" events describing each step.
        """
        agent = self.get_target_agent()

        agent_last_turn = None
        if self.graph and dialog:
            for turn in reversed(dialog):
                if turn.speaker == agent.get_name():
                    agent_last_turn = turn.text
                    break

        response = agent_last_turn if agent_last_turn else agent.response_lookahead()

        events = [Event(agent=agent.get_name(),
                        action="request_suggestions",
                        actionLabel=self.get_event_label(),
                        text=f'Previous response: "{response}"' if agent_last_turn else f'Lookahead response: "{response}"',
                        timestamp=int(time()))]

        sims = self.sent_encoder.similarity(self.sent_encoder.encode(response), self.resp_utt_embs)[0]
        top_k_ixs = sims.argsort(descending=True)[:self.top_k]

        if self.resp_acts is None:
            instruction = ("If applicable, try to pick your next response from the following list: " +
                           "; ".join(f'({ix + 1}) {resp}' for ix, resp in enumerate(self.resp_utts[top_k_ixs])))
        else:
            next_actions = self.resp_acts[top_k_ixs].tolist()
            events.append(Event(agent=agent.get_name(),
                                action="request_suggestions",
                                actionLabel=self.get_event_label(),
                                text="Actions for the response: " + ", ".join(next_actions),
                                timestamp=int(time())))
            if agent_last_turn:
                # The matched actions are what the agent just did: follow the
                # graph to get what should come next.
                next_actions = [self.graph[action] if action in self.graph else action
                                for action in next_actions]
                events.append(Event(agent=agent.get_name(),
                                    action="request_suggestions",
                                    actionLabel=self.get_event_label(),
                                    text="Graph next actions: " + ", ".join(next_actions),
                                    timestamp=int(time())))

            # Keep known actions only and drop repeats while preserving the
            # ranking order (several actions can lead to the same next one).
            next_actions = list(dict.fromkeys(action for action in next_actions
                                              if action in self.responses))
            instruction = ("If applicable, pick your next response from the following action list in order of importance: " +
                           "; ".join(f'({ix + 1}) Action: {action}. Response: "{self.responses[action]}"'
                                     for ix, action in enumerate(next_actions)))

        return Instruction(text=instruction, events=events)
205
+
206
+
207
class InstructionListOrchestrator(BaseOrchestrator):
    """Give the agent a scripted instruction depending on how many turns it has taken."""

    def __init__(self,
                 instructions: Union[List[str], Dict[int, str]],
                 persistent: bool = False):
        """
        :param instructions: either a list (item ``i`` is used once the agent
            has taken ``i`` turns) or a dict mapping that turn count to the
            instruction.
        """
        super().__init__(persistent=persistent)
        self.instructions = instructions

    def instruct(self, dialog: List[Turn], utterance: str) -> str:
        """Return the instruction for the agent's current turn count, or ``None``."""
        agent = self.get_target_agent()

        if dialog:
            current_user_len = sum(1 for t in dialog if t.speaker == agent.get_name())
        else:
            current_user_len = 0

        # isinstance so that dict / list subclasses are accepted too.
        if isinstance(self.instructions, dict):
            return self.instructions.get(current_user_len)
        if isinstance(self.instructions, list) and current_user_len < len(self.instructions):
            return self.instructions[current_user_len]
        return None
@@ -0,0 +1,299 @@
1
+ import json
2
+ import random
3
+
4
+ from time import time
5
+ from tqdm.auto import trange
6
+ from print_color import print
7
+ from typing import List, Union, Optional
8
+ from langchain_ollama.chat_models import ChatOllama
9
+ # from langchain_core.prompts import PromptTemplate
10
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
11
+
12
+ from . import Dialog, Turn, Event, Instruction
13
+ from .orchestrators import BaseOrchestrator
14
+ from .util import make_serializable
15
+
16
+
class Meta(type):
    """Metaclass that chains ``__init__`` calls through the class hierarchy automatically.

    Every class created with this metaclass gets its ``__init__`` replaced by a wrapper that
    first calls the ``__init__`` of each direct base class and then the class's own
    initializer, all with the same arguments. Subclasses therefore never need to call
    ``super().__init__`` themselves.

    The initializer visible on the class before wrapping is kept as ``cls.__init__child_``.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)

        # Only run the "child" initializer when it is genuinely this class's own one.
        # An inherited initializer is already executed by the loop over the bases, so
        # calling it again would initialize the instance twice. The bare
        # ``object.__init__`` is still invoked so that a class with no initializer at
        # all keeps rejecting unexpected arguments.
        run_child_init = "__init__" in dct or cls.__init__ is object.__init__

        def auto__call__init__(self, *a, **kw):
            for base in cls.__bases__:
                # ``object.__init__`` raises TypeError when given any extra argument,
                # so bases that do not define an initializer are skipped.
                if base.__init__ is not object.__init__:
                    base.__init__(self, *a, **kw)
            if run_child_init:
                cls.__init__child_(self, *a, **kw)

        cls.__init__child_ = cls.__init__
        cls.__init__ = auto__call__init__
25
+
26
+
class BasePersona(metaclass=Meta):
    """Base class for personas: a bag of named attributes describing a character.

    Keyword arguments passed to the constructor become instance attributes.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def description(self) -> str:
        """Return one ``Your <key>: <value>`` line per instance attribute, newline-joined."""
        lines = []
        for key, value in self.__dict__.items():
            lines.append(f"Your {key}: {value}")
        return "\n".join(lines)

    def __str__(self) -> str:
        return self.description()

    def json(self, string: bool = False, indent=None):
        """Export the instance attributes in a JSON-friendly form.

        :param string: when true return a JSON string, otherwise the dict itself.
        :param indent: passed to ``json.dumps`` when ``string`` is true.
        """
        # Work on a shallow copy so make_serializable does not alter the persona itself.
        data = self.__dict__.copy()
        make_serializable(data)
        if string:
            return json.dumps(data, indent=indent)
        return data
41
+
42
+
class Persona(BasePersona):
    """Standard persona with a fixed set of free-text fields, all defaulting to empty strings.

    The defaults are class attributes; a field only becomes an instance attribute when it
    is passed to the constructor.
    """

    # Who the character is.
    name: str = ""
    role: str = ""

    # What shapes the character.
    background: str = ""
    personality: str = ""
    circumstances: str = ""

    # Extra constraints for the character.
    rules: str = ""
50
+
51
+
class PersonaAgent:
    """Chat agent that role-plays a persona on top of a chat model.

    The conversation is kept in ``self.memory``, a list of chat messages whose first entry
    is the system prompt. Orchestrators attached to the agent may inject extra system
    instructions before each reply. :meth:`dialog_with` runs a full dialogue against
    another agent and returns it as a ``Dialog``.
    """

    # Marker searched for in the model output to detect that it ended the conversation.
    STOP_WORD = "STOP"
    # Text that replaces STOP_WORD in the utterance handed back to the caller.
    STOP_WORD_TEXT = "(bye bye!)"

    def __init__(self,
                 model: Union[str, ChatOllama],
                 persona: Optional[BasePersona] = None,
                 name: str = None,
                 dialogue_details: str = "",
                 response_details: str = "responses SHOULD NOT be too long and wordy, should be approximately one utterance long",
                 system_prompt: str = None,
                 can_finish: bool = False,
                 orchestrators: Union[BaseOrchestrator, List[BaseOrchestrator]] = None,
                 scenario: Union[dict, str] = None):
        """Create the agent and its initial memory.

        :param model: a model name (a ``ChatOllama`` is created for it) or a ready chat model.
        :param persona: persona to role-play; a fresh ``Persona()`` is used when omitted.
        :param name: agent name; falls back to ``persona.name`` when the persona has one.
        :param dialogue_details: optional text about the overall dialogue, added to the prompt.
        :param response_details: optional text about the expected responses, added to the prompt.
        :param system_prompt: full system prompt; when given, no prompt is generated.
        :param can_finish: whether the generated prompt tells the model to emit ``STOP_WORD``.
        :param orchestrators: orchestrator(s) to attach to the agent.
        :param scenario: stored as-is and used as the ``scenario`` of generated dialogs.
        """
        # A fresh instance per agent: a default created in the signature would be a single
        # object shared by every agent built without an explicit persona.
        if persona is None:
            persona = Persona()

        if not system_prompt:
            if can_finish:
                conversation_end_instructions = f"To finish the conversation you first have to say good bye and immediately after you **MUST** output '{self.STOP_WORD}' to indicate it is the end of it."
            else:
                conversation_end_instructions = "When the user finish the conversation you should say good bye and also finish the conversation"

            # system_prompt = prompt_template.format(role=role, ...)
            system_prompt = f"""Role play as a character that is described by the persona defined in the following lines. You always stay in character.
[[ ## BEGING PERSONA ## ]]
{persona}
[[ ## END PERSONA ## ]]
---
{"Details about the overall dialogue: " + dialogue_details if dialogue_details else ""}
{"Details about your responses: " + response_details if response_details else ""}
Finally, remember:
1. You always stay on character. You are the character described above.
2. Your first utterance / turn MUST always be a short generic greeting (e.g. "Hello, how are you?", "Hi!", "hey! what's up?", etc.), and nothing else, wait for a reply before start with the actual conversation.
3. {conversation_end_instructions}."""

        if isinstance(model, str):
            # TODO: ChatHuggingFace
            self.llm = ChatOllama(model=model,
                                  temperature=0.8,
                                  seed=13)
        else:
            self.llm = model
        self.memory = [SystemMessage(system_prompt)]

        self.name = name or getattr(persona, "name", None)
        self.persona = persona
        self.model_name = str(self.llm)
        self.first_utterances = None
        self.finished = False
        self.scenario = scenario
        self.orchestrators = None
        self.add_orchestrators(orchestrators)

    def __call__(self, utterance: str = "", return_events: bool = False) -> Union[str, List[Event], None]:
        """Produce the agent's next turn.

        :param utterance: what the other party just said; skipped when empty.
        :param return_events: when true return the list of generated ``Event`` objects
            (orchestrator events, instructions and the final ``"utter"`` event) instead of
            the response text.
        :return: ``None`` if the agent already finished; otherwise the response text, or
            the event list when ``return_events`` is true.
        """
        if self.finished:
            return None

        if utterance:
            self.memory.append(HumanMessage(content=utterance))

        events = [] if return_events else None
        for orchestrator in self.orchestrators or []:
            instruction = orchestrator()
            if not instruction:
                continue

            if isinstance(instruction, Instruction):
                if return_events and instruction.events:
                    # `events` may hold a single Event or a collection of them.
                    if isinstance(instruction.events, Event):
                        events.append(instruction.events)
                    else:
                        events.extend(instruction.events)
                instruction = instruction.text

            persist = orchestrator.is_persistent()
            self.instruct(instruction, persist=persist)
            if return_events:
                events.append(Event(agent=self.get_name(),
                                    action="instruct" + ("-persist" if persist else ""),
                                    actionLabel=orchestrator.get_event_label(),
                                    text=instruction,
                                    timestamp=int(time())))

        if len(self.memory) <= 1 and self.first_utterances:
            # Nothing but the system prompt in memory: open with a canned first utterance.
            if isinstance(self.first_utterances, list):
                first_utterance = random.choice(self.first_utterances)
            else:
                first_utterance = self.first_utterances
            response = AIMessage(content=first_utterance)
        else:
            response = self.llm.invoke(self.memory)

        if self.orchestrators:
            # Drop the instructions that were explicitly marked as non-persistent.
            self.memory[:] = [
                msg for msg in self.memory
                if not (msg.response_metadata
                        and "persist" in msg.response_metadata
                        and not msg.response_metadata["persist"])
            ]
        self.memory.append(response)

        response = response.content
        if self.STOP_WORD in response:
            response = response.replace(self.STOP_WORD, self.STOP_WORD_TEXT).strip()
            # The stop marker itself is not kept in the agent's memory.
            self.memory[-1].content = self.memory[-1].content.replace(self.STOP_WORD, "").strip()
            self.finished = True

        if return_events:
            if response:
                events.append(Event(agent=self.get_name(),
                                    action="utter",
                                    text=response,
                                    timestamp=int(time())))
            return events
        return response if response else ""

    def __or__(self, orchestrator: Union[BaseOrchestrator, List[BaseOrchestrator]]):
        """Attach orchestrator(s) with ``agent | orchestrator`` and return the agent itself."""
        self.add_orchestrators(orchestrator)
        return self

    def response_lookahead(self, utterance: str = None):
        """Return what the model would answer (optionally to ``utterance``) without touching memory."""
        if not utterance:
            return self.llm.invoke(self.memory).content
        return self.llm.invoke(self.memory + [HumanMessage(utterance)]).content

    def add_orchestrators(self, orchestrators):
        """Attach one orchestrator or a list of them and set this agent as their target."""
        if not orchestrators:
            return

        if self.orchestrators is None:
            self.orchestrators = []

        if isinstance(orchestrators, BaseOrchestrator):
            orchestrators = [orchestrators]

        self.orchestrators.extend(orchestrators)

        for orchestrator in orchestrators:
            orchestrator._set_target_agent(self)

    def clear_orchestrators(self):
        """Detach every orchestrator from the agent."""
        self.orchestrators = None

    def instruct(self, instruction: str, persist: bool = False):
        """Append ``instruction`` to memory as a system message tagged with its ``persist`` flag."""
        self.memory.append(SystemMessage(instruction, response_metadata={"persist": persist}))

    def set_first_utterances(self, utterances: Union[str, List[str]]):
        """Set the canned opening utterance, or a list from which one is picked at random."""
        self.first_utterances = utterances

    def get_name(self):
        """Return the agent's name (may be ``None``)."""
        return self.name

    def get_prompt(self):
        """Return the system prompt, i.e. the content of the first message in memory."""
        return self.memory[0].content

    def json(self, string: bool = False, indent=None):
        """Describe the agent (name, model, first utterances, persona, orchestrators).

        :param string: when true return a JSON string, otherwise the dict itself.
        :param indent: passed to ``json.dumps`` when ``string`` is true.
        """
        data = {}
        if self.name:
            data["name"] = self.name
        data["model_name"] = self.model_name
        if self.first_utterances:
            data["first_utterances"] = self.first_utterances
        data["persona"] = self.persona.json()
        if self.orchestrators:
            data["persona"]["orchestrators"] = [orc.json() for orc in self.orchestrators]
        return json.dumps(data, indent=indent) if string else data

    def reset(self, seed: int = None):
        """Forget the conversation, reset the orchestrators and re-seed the model.

        :param seed: value assigned to ``self.llm.seed``.
        """
        self.memory[:] = self.memory[:1]
        self.finished = False
        self.llm.seed = seed

        if self.orchestrators:
            for orchestrator in self.orchestrators:
                orchestrator.reset()

        # hack to avoid seed bug in prompt cache (to force a new cache, related to https://github.com/ollama/ollama/issues/5321)
        saved_num_predict = self.llm.num_predict
        self.llm.num_predict = 1
        try:
            self.llm.invoke(self.memory)
        finally:
            # Always restore the limit, even if the warm-up call fails; otherwise every
            # later response would stay limited by num_predict=1.
            self.llm.num_predict = saved_num_predict

    def dialog_with(self,
                    persona: "PersonaAgent",
                    max_iterations: int = 20,
                    id: int = None,
                    seed: int = None,
                    keep_bar: bool = True):
        """Run a dialogue between this agent (who speaks first) and another agent.

        :param persona: the other agent.
        :param max_iterations: maximum number of exchanges (one turn per agent each).
        :param id: identifier stored in the resulting dialog.
        :param seed: seed for ``random`` and both agents; drawn at random when omitted.
        :param keep_bar: whether the progress bar is left on screen when done.
        :return: a ``Dialog`` with the turns, the events and the generation metadata.
        """
        seed = seed if seed is not None else random.getrandbits(32)

        random.seed(seed)
        self.reset(seed)
        persona.reset(seed)

        dialog = []
        events = []

        utter = None
        completion = False
        tqdm_iterator = trange(max_iterations, desc="Dialogue", leave=keep_bar)
        for _ in tqdm_iterator:
            stop = False
            # Both agents take their turn the same way; only the fallback speaker name differs.
            for speaker, fallback_name in ((self, "Me"), (persona, "Other")):
                utt_events = speaker(utter, return_events=True)

                if not (utt_events and utt_events[-1].action == "utter"):
                    # The speaker had nothing (more) to say: the dialogue ended by itself.
                    completion = True
                    stop = True
                    break

                utter = utt_events[-1].text
                utt_events[-1].text = utter.replace(self.STOP_WORD_TEXT, "").strip()
                if not utt_events[-1].text:
                    stop = True
                    break

                dialog.append(Turn(
                    speaker=speaker.get_name() or fallback_name,
                    text=utt_events[-1].text
                ))
                events.extend(utt_events)

            if stop:
                break

        if not keep_bar:
            # Only the notebook flavour of tqdm has a widget `container`; the console one
            # does not, and already clears the bar itself because of leave=False.
            container = getattr(tqdm_iterator, "container", None)
            if container is not None:
                container.close()

        if self.scenario:
            scenario = self.scenario
        else:
            scenario = {
                "agents": [
                    self.json(),
                    persona.json()
                ]
            }

        return Dialog(
            dialogId=id,  # passed through as-is so that an id of 0 is not turned into None
            complete=completion,  # incomplete if ran out of iterations (reached max_iteration number)
            model=self.model_name,
            seed=seed,
            scenario=scenario,
            turns=dialog,
            events=events
        )

    talk_with = dialog_with
@@ -0,0 +1,11 @@
1
+ import json
2
+
def make_serializable(data: dict):
    """Replace, in place, every value of ``data`` that ``json.dumps`` cannot encode.

    Each top-level value is probed with ``json.dumps``; values that fail are replaced by
    their ``str()`` representation. Values that encode fine are left untouched.

    :param data: dictionary to sanitize (modified in place).
    :return: the same ``data`` object, for convenience.
    """
    for key, value in data.items():
        try:
            json.dumps(value)
        except (TypeError, OverflowError, ValueError):
            # ValueError is what json.dumps raises for circular references; without it
            # such a value would crash this helper instead of being stringified.
            data[key] = str(value)

    return data
@@ -0,0 +1,34 @@
1
+ Metadata-Version: 2.4
2
+ Name: sdialog
3
+ Version: 0.0.1
4
+ Summary: Synthetic Dialogue Generation and Analysis
5
+ Author-email: Sergio Burdisso <sergio.burdisso@gmail.com>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/idiap/sdialog
8
+ Project-URL: Issues, https://github.com/idiap/sdialog/issues
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.9
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: print-color
15
+ Requires-Dist: langchain
16
+ Requires-Dist: langchain-ollama
17
+ Requires-Dist: tqdm
18
+ Requires-Dist: plotly
19
+ Requires-Dist: sentence-transformers
20
+ Requires-Dist: pandas
21
+ Requires-Dist: tenacity
22
+ Requires-Dist: numpy
23
+ Requires-Dist: flake8
24
+ Requires-Dist: pytest
25
+ Requires-Dist: ollama
26
+ Dynamic: license-file
27
+
28
+ # SDialog
29
+
30
+ Synthetic Dialogue Generation and Analysis
31
+
32
+ _(Coming soon)_
33
+
34
+ This package requires `Ollama` running on your system.
@@ -0,0 +1,15 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ requirements.txt
5
+ src/sdialog/__init__.py
6
+ src/sdialog/datasets.py
7
+ src/sdialog/generators.py
8
+ src/sdialog/orchestrators.py
9
+ src/sdialog/personas.py
10
+ src/sdialog/util.py
11
+ src/sdialog.egg-info/PKG-INFO
12
+ src/sdialog.egg-info/SOURCES.txt
13
+ src/sdialog.egg-info/dependency_links.txt
14
+ src/sdialog.egg-info/requires.txt
15
+ src/sdialog.egg-info/top_level.txt
@@ -0,0 +1,12 @@
1
+ print-color
2
+ langchain
3
+ langchain-ollama
4
+ tqdm
5
+ plotly
6
+ sentence-transformers
7
+ pandas
8
+ tenacity
9
+ numpy
10
+ flake8
11
+ pytest
12
+ ollama
@@ -0,0 +1 @@
1
+ sdialog