PraisonAI 0.0.59rc2-cp312-cp312-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of PraisonAI might be problematic.
- praisonai/__init__.py +6 -0
- praisonai/__main__.py +10 -0
- praisonai/agents_generator.py +381 -0
- praisonai/auto.py +190 -0
- praisonai/chainlit_ui.py +304 -0
- praisonai/cli.py +337 -0
- praisonai/deploy.py +138 -0
- praisonai/inbuilt_tools/__init__.py +2 -0
- praisonai/inbuilt_tools/autogen_tools.py +209 -0
- praisonai/inc/__init__.py +2 -0
- praisonai/inc/models.py +128 -0
- praisonai/public/android-chrome-192x192.png +0 -0
- praisonai/public/android-chrome-512x512.png +0 -0
- praisonai/public/apple-touch-icon.png +0 -0
- praisonai/public/fantasy.svg +3 -0
- praisonai/public/favicon-16x16.png +0 -0
- praisonai/public/favicon-32x32.png +0 -0
- praisonai/public/favicon.ico +0 -0
- praisonai/public/game.svg +3 -0
- praisonai/public/logo_dark.png +0 -0
- praisonai/public/logo_light.png +0 -0
- praisonai/public/movie.svg +3 -0
- praisonai/public/thriller.svg +3 -0
- praisonai/test.py +105 -0
- praisonai/train.py +232 -0
- praisonai/ui/chat.py +304 -0
- praisonai/ui/code.py +318 -0
- praisonai/ui/context.py +283 -0
- praisonai/ui/public/fantasy.svg +3 -0
- praisonai/ui/public/game.svg +3 -0
- praisonai/ui/public/logo_dark.png +0 -0
- praisonai/ui/public/logo_light.png +0 -0
- praisonai/ui/public/movie.svg +3 -0
- praisonai/ui/public/thriller.svg +3 -0
- praisonai/ui/sql_alchemy.py +638 -0
- praisonai/version.py +1 -0
- praisonai-0.0.59rc2.dist-info/LICENSE +20 -0
- praisonai-0.0.59rc2.dist-info/METADATA +344 -0
- praisonai-0.0.59rc2.dist-info/RECORD +41 -0
- praisonai-0.0.59rc2.dist-info/WHEEL +4 -0
- praisonai-0.0.59rc2.dist-info/entry_points.txt +5 -0
praisonai/train.py
ADDED
@@ -0,0 +1,232 @@
import subprocess
import os
import sys
import shutil  # used by save_model_merged (missing from the published imports)
import yaml
import torch
from transformers import TextStreamer
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset, concatenate_datasets, Dataset
from psutil import virtual_memory

class train:
    def __init__(self, config_path="config.yaml"):
        self.load_config(config_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = None, None

    def load_config(self, path):
        with open(path, "r") as file:
            self.config = yaml.safe_load(file)

    def print_system_info(self):
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA version: {torch.version.cuda}")
        if torch.cuda.is_available():
            device_capability = torch.cuda.get_device_capability()
            print(f"CUDA Device Capability: {device_capability}")
        else:
            print("CUDA is not available")

        python_version = sys.version
        pip_version = subprocess.check_output(['pip', '--version']).decode().strip()
        python_path = sys.executable
        pip_path = subprocess.check_output(['which', 'pip']).decode().strip()
        print(f"Python Version: {python_version}")
        print(f"Pip Version: {pip_version}")
        print(f"Python Path: {python_path}")
        print(f"Pip Path: {pip_path}")

    def check_gpu(self):
        gpu_stats = torch.cuda.get_device_properties(0)
        print(f"GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)} GB.")

    def check_ram(self):
        ram_gb = virtual_memory().total / 1e9
        print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
        if ram_gb < 20:
            print('Not using a high-RAM runtime')
        else:
            print('You are using a high-RAM runtime!')

    # def install_packages(self):
    #     subprocess.run(["pip", "install", "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@4e570be9ae4ced8cdc64e498125708e34942befc"])
    #     subprocess.run(["pip", "install", "--no-deps", "trl<0.9.0", "peft==0.12.0", "accelerate==0.33.0", "bitsandbytes==0.43.3"])

    def prepare_model(self):
        # Load the base model in 4-bit, then wrap it with LoRA adapters.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            dtype=None,
            load_in_4bit=self.config["load_in_4bit"]
        )
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config["lora_r"],
            target_modules=self.config["lora_target_modules"],
            lora_alpha=self.config["lora_alpha"],
            lora_dropout=self.config["lora_dropout"],
            bias=self.config["lora_bias"],
            use_gradient_checkpointing=self.config["use_gradient_checkpointing"],
            random_state=self.config["random_state"],
            use_rslora=self.config["use_rslora"],
            loftq_config=self.config["loftq_config"],
        )

    def process_dataset(self, dataset_info):
        dataset_name = dataset_info["name"]
        split_type = dataset_info.get("split_type", "train")
        processing_func = getattr(self, dataset_info.get("processing_func", "format_prompts"))
        rename = dataset_info.get("rename", {})
        filter_data = dataset_info.get("filter_data", False)
        filter_column_value = dataset_info.get("filter_column_value", "id")
        filter_value = dataset_info.get("filter_value", "alpaca")
        num_samples = dataset_info.get("num_samples", 20000)

        dataset = load_dataset(dataset_name, split=split_type)

        if rename:
            dataset = dataset.rename_columns(rename)
        if filter_data:
            dataset = dataset.filter(lambda example: filter_value in example[filter_column_value]).shuffle(seed=42).select(range(num_samples))
        dataset = dataset.map(processing_func, batched=True)
        return dataset

    def format_prompts(self, examples):
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
        texts = [alpaca_prompt.format(ins, inp, out) + self.tokenizer.eos_token for ins, inp, out in zip(examples["instruction"], examples["input"], examples["output"])]
        return {"text": texts}

    def load_datasets(self):
        datasets = []
        for dataset_info in self.config["dataset"]:
            datasets.append(self.process_dataset(dataset_info))
        return concatenate_datasets(datasets)

    def train_model(self):
        dataset = self.load_datasets()
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            dataset_text_field=self.config["dataset_text_field"],
            max_seq_length=self.config["max_seq_length"],
            dataset_num_proc=self.config["dataset_num_proc"],
            packing=self.config["packing"],
            args=TrainingArguments(
                per_device_train_batch_size=self.config["per_device_train_batch_size"],
                gradient_accumulation_steps=self.config["gradient_accumulation_steps"],
                warmup_steps=self.config["warmup_steps"],
                num_train_epochs=self.config["num_train_epochs"],
                max_steps=self.config["max_steps"],
                learning_rate=self.config["learning_rate"],
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=self.config["logging_steps"],
                optim=self.config["optim"],
                weight_decay=self.config["weight_decay"],
                lr_scheduler_type=self.config["lr_scheduler_type"],
                seed=self.config["seed"],
                output_dir=self.config["output_dir"],
            ),
        )
        trainer.train()

    def inference(self, instruction, input_text):
        FastLanguageModel.for_inference(self.model)
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
        inputs = self.tokenizer([alpaca_prompt.format(instruction, input_text, "")], return_tensors="pt").to("cuda")
        outputs = self.model.generate(**inputs, max_new_tokens=64, use_cache=True)
        print(self.tokenizer.batch_decode(outputs))

    def save_model_merged(self):
        if os.path.exists(self.config["hf_model_name"]):
            shutil.rmtree(self.config["hf_model_name"])
        self.model.push_to_hub_merged(
            self.config["hf_model_name"],
            self.tokenizer,
            save_method="merged_16bit",
            token=os.getenv('HF_TOKEN')
        )

    def push_model_gguf(self):
        self.model.push_to_hub_gguf(
            self.config["hf_model_name"],
            self.tokenizer,
            quantization_method=self.config["quantization_method"],
            token=os.getenv('HF_TOKEN')
        )

    def prepare_modelfile_content(self):
        output_model = self.config["hf_model_name"]
        # NOTE: the stop-token values below appear stripped/truncated in this
        # rendering of the published source.
        return f"""FROM {output_model}/unsloth.Q5_K_M.gguf

TEMPLATE \"\"\"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.{{{{ if .Prompt }}}}

### Instruction:
{{{{ .Prompt }}}}

{{{{ end }}}}### Response:
{{{{ .Response }}}}\"\"\"

PARAMETER stop ""
PARAMETER stop ""
PARAMETER stop ""
PARAMETER stop ""
PARAMETER stop "<|reserved_special_token_"
"""

    def create_and_push_ollama_model(self):
        modelfile_content = self.prepare_modelfile_content()
        with open('Modelfile', 'w') as file:
            file.write(modelfile_content)

        # "ollama serve" blocks until the server exits, so start it in the
        # background; otherwise the create/push commands below never run.
        subprocess.Popen(["ollama", "serve"])
        subprocess.run(["ollama", "create", f"{self.config['ollama_model']}:{self.config['model_parameters']}", "-f", "Modelfile"])
        subprocess.run(["ollama", "push", f"{self.config['ollama_model']}:{self.config['model_parameters']}"])

    def run(self):
        self.print_system_info()
        self.check_gpu()
        self.check_ram()
        # self.install_packages()
        self.prepare_model()
        self.train_model()
        self.save_model_merged()
        self.push_model_gguf()
        self.create_and_push_ollama_model()


def main():
    import argparse
    parser = argparse.ArgumentParser(description='PraisonAI Training Script')
    parser.add_argument('command', choices=['train'], help='Command to execute')
    parser.add_argument('--config', default='config.yaml', help='Path to configuration file')
    args = parser.parse_args()

    if args.command == 'train':
        ai = train(config_path=args.config)
        ai.run()


if __name__ == '__main__':
    main()
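Every hyperparameter above is pulled from config.yaml. As a reference point, the sketch below writes a minimal config covering each key the class reads; the model, dataset, and repository names are illustrative assumptions, not the package's documented defaults.

# Hypothetical config writer: every key matches a self.config[...] lookup in
# train.py above; the values are assumptions chosen for a small test run.
import yaml

config = {
    "model_name": "unsloth/llama-3-8b-bnb-4bit",  # assumption: any Unsloth-compatible base
    "max_seq_length": 2048,
    "load_in_4bit": True,
    "lora_r": 16,
    "lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "lora_alpha": 16,
    "lora_dropout": 0,
    "lora_bias": "none",
    "use_gradient_checkpointing": "unsloth",
    "random_state": 3407,
    "use_rslora": False,
    "loftq_config": None,
    "dataset": [{"name": "yahma/alpaca-cleaned", "split_type": "train"}],  # assumption
    "dataset_text_field": "text",
    "dataset_num_proc": 2,
    "packing": False,
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "warmup_steps": 5,
    "num_train_epochs": 1,
    "max_steps": 60,
    "learning_rate": 2e-4,
    "logging_steps": 1,
    "optim": "adamw_8bit",
    "weight_decay": 0.01,
    "lr_scheduler_type": "linear",
    "seed": 3407,
    "output_dir": "outputs",
    "hf_model_name": "your-hf-username/your-model",  # assumption: HF repo id
    "quantization_method": "q5_k_m",
    "ollama_model": "your-ollama-model",  # assumption
    "model_parameters": "8b",  # assumption
}

with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)

With that file in place, the script's own entry point runs the whole pipeline: python train.py train --config config.yaml, or programmatically, train(config_path="config.yaml").run().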
praisonai/ui/chat.py
ADDED
@@ -0,0 +1,304 @@
import chainlit as cl
from chainlit.input_widget import TextInput
from chainlit.types import ThreadDict
from litellm import acompletion
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional
from dotenv import load_dotenv
load_dotenv()
import chainlit.data as cl_data
from chainlit.step import StepDict
from literalai.helper import utc_now
import logging
import json
from sql_alchemy import SQLAlchemyDataLayer

# Set up logging
logger = logging.getLogger(__name__)
log_level = os.getenv("LOGLEVEL", "INFO").upper()
logger.handlers = []

# Set up logging to console
console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)

# Set the logging level for the logger
logger.setLevel(log_level)

CHAINLIT_AUTH_SECRET = os.getenv("CHAINLIT_AUTH_SECRET")

if not CHAINLIT_AUTH_SECRET:
    os.environ["CHAINLIT_AUTH_SECRET"] = "p8BPhQChpg@J>jBz$wGxqLX2V>yTVgP*7Ky9H$aV:axW~ANNX-7_T:o@lnyCBu^U"
    CHAINLIT_AUTH_SECRET = os.getenv("CHAINLIT_AUTH_SECRET")

now = utc_now()

create_step_counter = 0

DB_PATH = os.path.expanduser("~/.praison/database.sqlite")

def initialize_db():
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id UUID PRIMARY KEY,
            identifier TEXT NOT NULL UNIQUE,
            metadata JSONB NOT NULL,
            createdAt TEXT
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS threads (
            id UUID PRIMARY KEY,
            createdAt TEXT,
            name TEXT,
            userId UUID,
            userIdentifier TEXT,
            tags TEXT[],
            metadata JSONB NOT NULL DEFAULT '{}',
            FOREIGN KEY (userId) REFERENCES users(id) ON DELETE CASCADE
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS steps (
            id UUID PRIMARY KEY,
            name TEXT NOT NULL,
            type TEXT NOT NULL,
            threadId UUID NOT NULL,
            parentId UUID,
            disableFeedback BOOLEAN NOT NULL,
            streaming BOOLEAN NOT NULL,
            waitForAnswer BOOLEAN,
            isError BOOLEAN,
            metadata JSONB,
            tags TEXT[],
            input TEXT,
            output TEXT,
            createdAt TEXT,
            start TEXT,
            end TEXT,
            generation JSONB,
            showInput TEXT,
            language TEXT,
            indent INT,
            FOREIGN KEY (threadId) REFERENCES threads (id) ON DELETE CASCADE
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS elements (
            id UUID PRIMARY KEY,
            threadId UUID,
            type TEXT,
            url TEXT,
            chainlitKey TEXT,
            name TEXT NOT NULL,
            display TEXT,
            objectKey TEXT,
            size TEXT,
            page INT,
            language TEXT,
            forId UUID,
            mime TEXT,
            FOREIGN KEY (threadId) REFERENCES threads (id) ON DELETE CASCADE
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS feedbacks (
            id UUID PRIMARY KEY,
            forId UUID NOT NULL,
            value INT NOT NULL,
            threadId UUID,
            comment TEXT
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS settings (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            key TEXT UNIQUE,
            value TEXT
        )
    ''')
    conn.commit()
    conn.close()

def save_setting(key: str, value: str):
    """Saves a setting to the database.

    Args:
        key: The setting key.
        value: The setting value.
    """
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute(
        """
        INSERT OR REPLACE INTO settings (id, key, value)
        VALUES ((SELECT id FROM settings WHERE key = ?), ?, ?)
        """,
        (key, key, value),
    )
    conn.commit()
    conn.close()

def load_setting(key: str) -> str:
    """Loads a setting from the database.

    Args:
        key: The setting key.

    Returns:
        The setting value, or None if the key is not found.
    """
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute('SELECT value FROM settings WHERE key = ?', (key,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None


# Initialize the database
initialize_db()

deleted_thread_ids = []  # type: List[str]

cl_data._data_layer = SQLAlchemyDataLayer(conninfo=f"sqlite+aiosqlite:///{DB_PATH}")

@cl.on_chat_start
async def start():
    initialize_db()
    model_name = load_setting("model_name")

    if model_name:
        cl.user_session.set("model_name", model_name)
    else:
        # If no setting found, use default or environment variable
        model_name = os.getenv("MODEL_NAME", "gpt-4o-mini")
        cl.user_session.set("model_name", model_name)
    logger.debug(f"Model name: {model_name}")
    settings = cl.ChatSettings(
        [
            TextInput(
                id="model_name",
                label="Enter the Model Name",
                placeholder="e.g., gpt-4o-mini",
                initial=model_name
            )
        ]
    )
    cl.user_session.set("settings", settings)
    await settings.send()

@cl.on_settings_update
async def setup_agent(settings):
    logger.debug(settings)
    cl.user_session.set("settings", settings)
    model_name = settings["model_name"]
    cl.user_session.set("model_name", model_name)

    # Save in settings table
    save_setting("model_name", model_name)

    # Save in thread metadata
    thread_id = cl.user_session.get("thread_id")
    if thread_id:
        thread = await cl_data.get_thread(thread_id)
        if thread:
            metadata = thread.get("metadata", {})
            metadata["model_name"] = model_name

            # Always store metadata as a JSON string
            await cl_data.update_thread(thread_id, metadata=json.dumps(metadata))

            # Update the user session with the new metadata
            cl.user_session.set("metadata", metadata)

@cl.on_message
async def main(message: cl.Message):
    model_name = load_setting("model_name") or os.getenv("MODEL_NAME") or "gpt-4o-mini"
    message_history = cl.user_session.get("message_history", [])
    message_history.append({"role": "user", "content": message.content})

    msg = cl.Message(content="")
    await msg.send()

    response = await acompletion(
        model=model_name,
        messages=message_history,
        stream=True,
        # temperature=0.7,
        # max_tokens=500,
        # top_p=1
    )

    full_response = ""
    async for part in response:
        if token := part['choices'][0]['delta']['content']:
            await msg.stream_token(token)
            full_response += token
    logger.debug(f"Full response: {full_response}")
    message_history.append({"role": "assistant", "content": full_response})
    logger.debug(f"Message history: {message_history}")
    cl.user_session.set("message_history", message_history)
    await msg.update()

username = os.getenv("CHAINLIT_USERNAME", "admin")  # Default to "admin" if not found
password = os.getenv("CHAINLIT_PASSWORD", "admin")  # Default to "admin" if not found

@cl.password_auth_callback
def auth_callback(entered_username: str, entered_password: str):
    # Check the submitted credentials against the configured ones; the
    # published code compared the arguments to themselves, accepting any login.
    if (entered_username, entered_password) == (username, password):
        return cl.User(
            identifier=entered_username, metadata={"role": "ADMIN", "provider": "credentials"}
        )
    else:
        return None

async def send_count():
    await cl.Message(
        f"Create step counter: {create_step_counter}", disable_feedback=True
    ).send()

@cl.on_chat_resume
async def on_chat_resume(thread: cl_data.ThreadDict):
    logger.info(f"Resuming chat: {thread['id']}")
    model_name = load_setting("model_name") or os.getenv("MODEL_NAME") or "gpt-4o-mini"
    logger.debug(f"Model name: {model_name}")
    settings = cl.ChatSettings(
        [
            TextInput(
                id="model_name",
                label="Enter the Model Name",
                placeholder="e.g., gpt-4o-mini",
                initial=model_name
            )
        ]
    )
    await settings.send()
    thread_id = thread["id"]
    cl.user_session.set("thread_id", thread["id"])

    # The metadata should now already be a dictionary
    metadata = thread.get("metadata", {})
    cl.user_session.set("metadata", metadata)

    message_history = cl.user_session.get("message_history", [])
    steps = thread["steps"]

    for message in steps:
        msg_type = message.get("type")
        if msg_type == "user_message":
            message_history.append({"role": "user", "content": message.get("output", "")})
        elif msg_type == "assistant_message":
            message_history.append({"role": "assistant", "content": message.get("output", "")})
        else:
            logger.warning(f"Message without type: {message}")

    cl.user_session.set("message_history", message_history)