gptmed 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/framework/cli/__init__.py +18 -2
- gptmed/framework/cli/__main__.py +7 -0
- gptmed/framework/cli/startproject.py +845 -7
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/METADATA +1 -1
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/RECORD +9 -8
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/WHEEL +0 -0
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/entry_points.txt +0 -0
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/top_level.txt +0 -0
gptmed/framework/cli/__init__.py
CHANGED
@@ -3,6 +3,22 @@ from .startproject import startproject
 
 def main():
     if len(sys.argv) < 3 or sys.argv[1] != "startproject":
-        print("Usage: gptmed startproject <projectname>")
+        print("Usage: gptmed startproject <projectname> [--qna|--conversational]")
         sys.exit(1)
-
+
+    project_name = sys.argv[2]
+    project_type = None
+
+    # Check for flags
+    if len(sys.argv) > 3:
+        flag = sys.argv[3]
+        if flag == "--qna":
+            project_type = "qna"
+        elif flag == "--conversational":
+            project_type = "conversational"
+        else:
+            print(f"Invalid flag: {flag}")
+            print("Usage: gptmed startproject <projectname> [--qna|--conversational]")
+            sys.exit(1)
+
+    startproject(project_name, project_type)
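The upshot of this change: `gptmed startproject` now accepts an optional `--qna` or `--conversational` flag and threads it through to `startproject`. A minimal sketch of the equivalent programmatic calls (the import path matches `gptmed/framework/cli/startproject.py` in this diff; the project names are hypothetical):

```python
# Equivalent to the three CLI invocations the new dispatch supports.
# Import path is per this diff; the project names below are hypothetical.
from gptmed.framework.cli.startproject import startproject

startproject("med_qna", "qna")              # gptmed startproject med_qna --qna
startproject("med_chat", "conversational")  # gptmed startproject med_chat --conversational
startproject("med_basic", None)             # gptmed startproject med_basic
```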
gptmed/framework/cli/startproject.py
CHANGED
@@ -1,17 +1,855 @@
 import os
 import sys
 
-def
-
-
-
-
-
+def create_qna_templates(project_name):
+    """Create boilerplate for QNA model training architecture"""
+
+    # Create directory structure
+    os.makedirs(os.path.join(project_name, "configs"))
+    os.makedirs(os.path.join(project_name, "data", "raw"))
+    os.makedirs(os.path.join(project_name, "data", "processed"))
+    os.makedirs(os.path.join(project_name, "models", "checkpoints"))
+    os.makedirs(os.path.join(project_name, "tokenizer"))
+    os.makedirs(os.path.join(project_name, "logs"))
+    os.makedirs(os.path.join(project_name, "inference"))
+
+    # Create main.py
+    with open(os.path.join(project_name, "main.py"), "w") as f:
+        f.write("""\"\"\"
+Main entry point for QNA Model Training
+\"\"\"
+import gptmed
+from pathlib import Path
+
+def main():
+    # Step 1: Create configuration
+    config_path = 'configs/training_config.yaml'
+    if not Path(config_path).exists():
+        gptmed.create_config(config_path)
+        print(f"Configuration file created at {config_path}")
+        print("Please edit the configuration file and run again.")
+        return
+
+    # Step 2: Train the model
+    print("Starting QNA model training...")
+    results = gptmed.train_from_config(config_path, device='auto')
+
+    print(f"\\nTraining completed!")
+    print(f"Best checkpoint: {results['best_checkpoint']}")
+    print(f"Final validation loss: {results['final_val_loss']}")
+
+if __name__ == "__main__":
+    main()
+""")
+
+    # Create preprocess.py
+    with open(os.path.join(project_name, "preprocess.py"), "w") as f:
+        f.write("""\"\"\"
+Data preprocessing for QNA dataset
+\"\"\"
+import json
+from pathlib import Path
+
+def preprocess_qna_data(input_file, output_file):
+    \"\"\"
+    Preprocess QNA data from raw format to training format.
+
+    Expected input format (JSONL):
+    {"question": "What is X?", "answer": "X is..."}
+
+    Output format:
+    {"text": "Q: What is X?\\nA: X is..."}
+    \"\"\"
+    processed_data = []
+
+    with open(input_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            item = json.loads(line)
+            question = item.get('question', '').strip()
+            answer = item.get('answer', '').strip()
+
+            if question and answer:
+                formatted_text = f"Q: {question}\\nA: {answer}"
+                processed_data.append({"text": formatted_text})
+
+    # Save processed data
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(output_file, 'w', encoding='utf-8') as f:
+        for item in processed_data:
+            f.write(json.dumps(item) + '\\n')
+
+    print(f"Processed {len(processed_data)} QNA pairs")
+    print(f"Saved to {output_file}")
+
+if __name__ == "__main__":
+    # Example usage
+    preprocess_qna_data(
+        'data/raw/qna_data.jsonl',
+        'data/processed/train.jsonl'
+    )
+""")
+
+    # Create inference script
+    with open(os.path.join(project_name, "inference", "generate_answer.py"), "w") as f:
+        f.write("""\"\"\"
+Generate answers using trained QNA model
+\"\"\"
+import gptmed
+from pathlib import Path
+
+def generate_answer(question, checkpoint_path, tokenizer_path, max_length=200):
+    \"\"\"Generate answer for a given question\"\"\"
+
+    # Format question
+    prompt = f"Q: {question}\\nA:"
+
+    # Generate answer
+    answer = gptmed.generate(
+        checkpoint=checkpoint_path,
+        prompt=prompt,
+        tokenizer=tokenizer_path,
+        max_length=max_length,
+        temperature=0.7,
+        top_k=50
+    )
+
+    # Extract answer (remove the prompt)
+    answer = answer.replace(prompt, '').strip()
+
+    return answer
+
+if __name__ == "__main__":
+    # Example usage
+    checkpoint = "models/checkpoints/best_model.pt"
+    tokenizer = "tokenizer/qna_tokenizer.model"
+
+    question = "What is machine learning?"
+    answer = generate_answer(question, checkpoint, tokenizer)
+
+    print(f"Question: {question}")
+    print(f"Answer: {answer}")
+""")
+
+    # Create README.md
+    with open(os.path.join(project_name, "README.md"), "w") as f:
+        f.write(f"""# {project_name} - QNA Model
+
+Question and Answer generation model training architecture.
+
+## Directory Structure
+
+```
+{project_name}/
+├── configs/              # Training configurations
+├── data/
+│   ├── raw/              # Raw QNA data
+│   └── processed/        # Preprocessed data
+├── models/
+│   └── checkpoints/      # Model checkpoints
+├── tokenizer/            # Tokenizer files
+├── logs/                 # Training logs
+├── inference/            # Inference scripts
+├── main.py               # Main training script
+├── preprocess.py         # Data preprocessing
+└── README.md
+```
+
+## Getting Started
+
+### 1. Prepare Data
+
+Place your QNA data in JSONL format in `data/raw/qna_data.jsonl`:
+
+```json
+{{"question": "What is X?", "answer": "X is..."}}
+{{"question": "How does Y work?", "answer": "Y works by..."}}
+```
+
+### 2. Preprocess Data
+
+```bash
+python preprocess.py
+```
+
+### 3. Configure Training
+
+```bash
+python main.py  # This will create a default config file
+```
+
+Edit `configs/training_config.yaml` with your settings.
+
+### 4. Train Model
+
+```bash
+python main.py
+```
+
+### 5. Generate Answers
+
+```bash
+python inference/generate_answer.py
+```
+
+## Training Configuration
+
+Edit `configs/training_config.yaml` to customize:
+- Model size (tiny, small, medium)
+- Training parameters (epochs, batch size, learning rate)
+- Data paths
+- Device selection (CPU/GPU)
+
+## Inference
+
+Use the trained model to generate answers:
+
+```python
+import gptmed
+
+answer = gptmed.generate(
+    checkpoint='models/checkpoints/best_model.pt',
+    prompt='Q: Your question?\\nA:',
+    tokenizer='tokenizer/qna_tokenizer.model'
+)
+```
+""")
+
+def create_conversational_templates(project_name):
+    """Create boilerplate for conversational model training architecture"""
+
+    # Create directory structure
+    os.makedirs(os.path.join(project_name, "configs"))
+    os.makedirs(os.path.join(project_name, "data", "raw"))
+    os.makedirs(os.path.join(project_name, "data", "processed"))
+    os.makedirs(os.path.join(project_name, "models", "checkpoints"))
+    os.makedirs(os.path.join(project_name, "tokenizer"))
+    os.makedirs(os.path.join(project_name, "logs"))
+    os.makedirs(os.path.join(project_name, "inference"))
+    os.makedirs(os.path.join(project_name, "utils"))
+
+    # Create main.py
+    with open(os.path.join(project_name, "main.py"), "w") as f:
+        f.write("""\"\"\"
+Main entry point for Conversational Model Training
+\"\"\"
+import gptmed
+from pathlib import Path
+
+def main():
+    # Step 1: Create configuration
+    config_path = 'configs/training_config.yaml'
+    if not Path(config_path).exists():
+        gptmed.create_config(config_path)
+        print(f"Configuration file created at {config_path}")
+        print("Please edit the configuration file and run again.")
+        return
+
+    # Step 2: Train the model
+    print("Starting conversational model training...")
+    results = gptmed.train_from_config(config_path, device='auto')
+
+    print(f"\\nTraining completed!")
+    print(f"Best checkpoint: {results['best_checkpoint']}")
+    print(f"Final validation loss: {results['final_val_loss']}")
+
+if __name__ == "__main__":
+    main()
+""")
+
+    # Create preprocess.py for conversational data
+    with open(os.path.join(project_name, "preprocess.py"), "w") as f:
+        f.write("""\"\"\"
+Data preprocessing for conversational dataset
+\"\"\"
+import json
+from pathlib import Path
+from typing import List, Dict
+
+def format_conversation(messages: List[Dict[str, str]],
+                        user_token="<|user|>",
+                        assistant_token="<|assistant|>",
+                        end_token="<|endoftext|>") -> str:
+    \"\"\"
+    Format a conversation into training text.
+
+    Args:
+        messages: List of message dicts with 'role' and 'content'
+        user_token: Token to mark user messages
+        assistant_token: Token to mark assistant messages
+        end_token: Token to mark end of conversation
+
+    Returns:
+        Formatted conversation string
+    \"\"\"
+    conversation = []
+
+    for msg in messages:
+        role = msg.get('role', '')
+        content = msg.get('content', '').strip()
+
+        if role == 'user':
+            conversation.append(f"{user_token} {content}")
+        elif role == 'assistant':
+            conversation.append(f"{assistant_token} {content}")
+
+    return "\\n".join(conversation) + f"\\n{end_token}"
+
+def preprocess_conversational_data(input_file, output_file):
+    \"\"\"
+    Preprocess conversational data from raw format to training format.
+
+    Expected input format (JSONL):
+    {
+        "conversation": [
+            {"role": "user", "content": "Hello!"},
+            {"role": "assistant", "content": "Hi! How can I help?"},
+            {"role": "user", "content": "Tell me about AI"},
+            {"role": "assistant", "content": "AI is..."}
+        ]
+    }
+
+    Output format:
+    {"text": "<|user|> Hello!\\n<|assistant|> Hi! How can I help?\\n..."}
+    \"\"\"
+    processed_data = []
+
+    with open(input_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            item = json.loads(line)
+            conversation = item.get('conversation', [])
+
+            if len(conversation) >= 2:  # At least one exchange
+                formatted_text = format_conversation(conversation)
+                processed_data.append({"text": formatted_text})
+
+    # Save processed data
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(output_file, 'w', encoding='utf-8') as f:
+        for item in processed_data:
+            f.write(json.dumps(item) + '\\n')
+
+    print(f"Processed {len(processed_data)} conversations")
+    print(f"Saved to {output_file}")
+
+if __name__ == "__main__":
+    # Example usage
+    preprocess_conversational_data(
+        'data/raw/conversations.jsonl',
+        'data/processed/train.jsonl'
+    )
+""")
+
+    # Create conversation handler utility
+    with open(os.path.join(project_name, "utils", "conversation_handler.py"), "w") as f:
+        f.write("""\"\"\"
+Conversation management utilities
+\"\"\"
+from typing import List, Dict, Optional
+
+class ConversationHistory:
+    \"\"\"Manage conversation history for multi-turn dialogues\"\"\"
+
+    def __init__(self, max_history: int = 10):
+        \"\"\"
+        Initialize conversation history.
+
+        Args:
+            max_history: Maximum number of turns to keep in history
+        \"\"\"
+        self.messages: List[Dict[str, str]] = []
+        self.max_history = max_history
+
+    def add_user_message(self, content: str):
+        \"\"\"Add a user message to history\"\"\"
+        self.messages.append({"role": "user", "content": content})
+        self._trim_history()
+
+    def add_assistant_message(self, content: str):
+        \"\"\"Add an assistant message to history\"\"\"
+        self.messages.append({"role": "assistant", "content": content})
+        self._trim_history()
+
+    def _trim_history(self):
+        \"\"\"Keep only the most recent messages\"\"\"
+        if len(self.messages) > self.max_history * 2:  # 2 messages per turn
+            self.messages = self.messages[-(self.max_history * 2):]
+
+    def get_prompt(self,
+                   user_token: str = "<|user|>",
+                   assistant_token: str = "<|assistant|>") -> str:
+        \"\"\"
+        Generate prompt from conversation history.
+
+        Returns:
+            Formatted conversation prompt
+        \"\"\"
+        prompt_parts = []
+
+        for msg in self.messages:
+            role = msg['role']
+            content = msg['content']
+
+            if role == 'user':
+                prompt_parts.append(f"{user_token} {content}")
+            elif role == 'assistant':
+                prompt_parts.append(f"{assistant_token} {content}")
+
+        # Add assistant token for next response
+        prompt_parts.append(assistant_token)
+
+        return "\\n".join(prompt_parts) + " "
+
+    def clear(self):
+        \"\"\"Clear conversation history\"\"\"
+        self.messages = []
+
+    def get_last_n_turns(self, n: int) -> List[Dict[str, str]]:
+        \"\"\"Get last n conversation turns\"\"\"
+        return self.messages[-(n * 2):]
+""")
+
+    # Create interactive chat script
+    with open(os.path.join(project_name, "inference", "interactive_chat.py"), "w") as f:
+        f.write("""\"\"\"
+Interactive chat with conversational model
+\"\"\"
+import sys
+sys.path.append('..')
+
+import gptmed
+from pathlib import Path
+from utils.conversation_handler import ConversationHistory
+
+class ChatBot:
+    \"\"\"Interactive chatbot using trained conversational model\"\"\"
+
+    def __init__(self, checkpoint_path: str, tokenizer_path: str):
+        self.checkpoint_path = checkpoint_path
+        self.tokenizer_path = tokenizer_path
+        self.history = ConversationHistory(max_history=5)
+        self.user_token = "<|user|>"
+        self.assistant_token = "<|assistant|>"
+        self.end_token = "<|endoftext|>"
+
+    def generate_response(self, user_input: str, max_length: int = 150) -> str:
+        \"\"\"Generate response to user input\"\"\"
+
+        # Add user message to history
+        self.history.add_user_message(user_input)
+
+        # Get prompt from conversation history
+        prompt = self.history.get_prompt(self.user_token, self.assistant_token)
+
+        # Generate response
+        response = gptmed.generate(
+            checkpoint=self.checkpoint_path,
+            prompt=prompt,
+            tokenizer=self.tokenizer_path,
+            max_length=max_length,
+            temperature=0.8,
+            top_k=50,
+            top_p=0.9
+        )
+
+        # Extract assistant response
+        response = response.replace(prompt, '').strip()
+
+        # Remove end token if present
+        if self.end_token in response:
+            response = response.split(self.end_token)[0].strip()
+
+        # Remove user token if model generated it (shouldn't happen but just in case)
+        if self.user_token in response:
+            response = response.split(self.user_token)[0].strip()
+
+        # Add assistant response to history
+        self.history.add_assistant_message(response)
+
+        return response
+
+    def chat(self):
+        \"\"\"Start interactive chat session\"\"\"
+        print("Conversational AI Chatbot")
+        print("Type 'quit' or 'exit' to end the conversation")
+        print("Type 'clear' to reset conversation history")
+        print("-" * 50)
+
+        while True:
+            try:
+                user_input = input("You: ").strip()
+
+                if not user_input:
+                    continue
+
+                if user_input.lower() in ['quit', 'exit']:
+                    print("Goodbye!")
+                    break
+
+                if user_input.lower() == 'clear':
+                    self.history.clear()
+                    print("Conversation history cleared.")
+                    continue
+
+                response = self.generate_response(user_input)
+                print(f"Bot: {response}")
+
+            except KeyboardInterrupt:
+                print("\\nGoodbye!")
+                break
+            except Exception as e:
+                print(f"Error: {e}")
+
+if __name__ == "__main__":
+    # Initialize chatbot
+    checkpoint = "../models/checkpoints/best_model.pt"
+    tokenizer = "../tokenizer/conv_tokenizer.model"
+
+    if not Path(checkpoint).exists():
+        print(f"Error: Checkpoint not found at {checkpoint}")
+        print("Please train your model first using main.py")
         sys.exit(1)
+
+    chatbot = ChatBot(checkpoint, tokenizer)
+    chatbot.chat()
+""")
+
+    # Create batch inference script
+    with open(os.path.join(project_name, "inference", "batch_inference.py"), "w") as f:
+        f.write("""\"\"\"
+Batch inference for conversational model
+\"\"\"
+import sys
+sys.path.append('..')
+
+import json
+import gptmed
+from pathlib import Path
+from utils.conversation_handler import ConversationHistory
+from typing import List, Dict
+
+def generate_conversation_response(
+    conversation_history: List[Dict[str, str]],
+    checkpoint_path: str,
+    tokenizer_path: str,
+    max_length: int = 150
+) -> str:
+    \"\"\"
+    Generate response for a conversation.
+
+    Args:
+        conversation_history: List of messages with role and content
+        checkpoint_path: Path to model checkpoint
+        tokenizer_path: Path to tokenizer
+        max_length: Maximum response length
+
+    Returns:
+        Generated response
+    \"\"\"
+    history = ConversationHistory()
+
+    # Rebuild conversation history
+    for msg in conversation_history:
+        if msg['role'] == 'user':
+            history.add_user_message(msg['content'])
+        elif msg['role'] == 'assistant':
+            history.add_assistant_message(msg['content'])
+
+    # Get prompt
+    prompt = history.get_prompt()
+
+    # Generate response
+    response = gptmed.generate(
+        checkpoint=checkpoint_path,
+        prompt=prompt,
+        tokenizer=tokenizer_path,
+        max_length=max_length,
+        temperature=0.8,
+        top_k=50
+    )
+
+    # Clean up response
+    response = response.replace(prompt, '').strip()
+    if "<|endoftext|>" in response:
+        response = response.split("<|endoftext|>")[0].strip()
+
+    return response
+
+def batch_process_conversations(
+    input_file: str,
+    output_file: str,
+    checkpoint_path: str,
+    tokenizer_path: str
+):
+    \"\"\"
+    Process multiple conversations in batch.
+
+    Input format (JSONL):
+    {"conversation": [{"role": "user", "content": "Hi"}, ...]}
+
+    Output format (JSONL):
+    {"conversation": [...], "generated_response": "..."}
+    \"\"\"
+    results = []
+
+    with open(input_file, 'r') as f:
+        for line in f:
+            item = json.loads(line)
+            conversation = item.get('conversation', [])
+
+            if conversation:
+                response = generate_conversation_response(
+                    conversation,
+                    checkpoint_path,
+                    tokenizer_path
+                )
+
+                results.append({
+                    "conversation": conversation,
+                    "generated_response": response
+                })
+
+    # Save results
+    with open(output_file, 'w') as f:
+        for result in results:
+            f.write(json.dumps(result) + '\\n')
+
+    print(f"Processed {len(results)} conversations")
+    print(f"Results saved to {output_file}")
+
+if __name__ == "__main__":
+    batch_process_conversations(
+        input_file='test_conversations.jsonl',
+        output_file='results.jsonl',
+        checkpoint_path='../models/checkpoints/best_model.pt',
+        tokenizer_path='../tokenizer/conv_tokenizer.model'
+    )
+""")
+
+    # Create README.md
+    with open(os.path.join(project_name, "README.md"), "w") as f:
+        f.write(f"""# {project_name} - Conversational Model
+
+Multi-turn conversational language model training architecture.
+
+## Directory Structure
+
+```
+{project_name}/
+├── configs/                  # Training configurations
+├── data/
+│   ├── raw/                  # Raw conversation data
+│   └── processed/            # Preprocessed data
+├── models/
+│   └── checkpoints/          # Model checkpoints
+├── tokenizer/                # Tokenizer files
+├── logs/                     # Training logs
+├── inference/                # Inference scripts
+│   ├── interactive_chat.py   # Interactive chat interface
+│   └── batch_inference.py    # Batch processing
+├── utils/                    # Utility modules
+│   └── conversation_handler.py
+├── main.py                   # Main training script
+├── preprocess.py             # Data preprocessing
+└── README.md
+```
+
+## Getting Started
+
+### 1. Prepare Data
+
+Place your conversational data in JSONL format in `data/raw/conversations.jsonl`:
+
+```json
+{{
+    "conversation": [
+        {{"role": "user", "content": "Hello!"}},
+        {{"role": "assistant", "content": "Hi! How can I help you today?"}},
+        {{"role": "user", "content": "Tell me about AI"}},
+        {{"role": "assistant", "content": "AI stands for Artificial Intelligence..."}}
+    ]
+}}
+```
+
+### 2. Preprocess Data
+
+```bash
+python preprocess.py
+```
+
+This will format conversations with special tokens:
+- `<|user|>` - marks user messages
+- `<|assistant|>` - marks assistant messages
+- `<|endoftext|>` - marks end of conversation
+
+### 3. Configure Training
+
+```bash
+python main.py  # Creates default config
+```
+
+Edit `configs/training_config.yaml` with your settings.
+
+### 4. Train Model
+
+```bash
+python main.py
+```
+
+### 5. Interactive Chat
+
+```bash
+cd inference
+python interactive_chat.py
+```
+
+Commands:
+- Type your message and press Enter
+- Type `clear` to reset conversation history
+- Type `quit` or `exit` to end chat
+
+### 6. Batch Inference
+
+```bash
+cd inference
+python batch_inference.py
+```
+
+## Features
+
+### Multi-turn Conversations
+The model maintains conversation history and generates contextually relevant responses.
+
+### Conversation Management
+The `ConversationHistory` class manages:
+- Message history tracking
+- Automatic history trimming
+- Prompt generation from history
+
+### Interactive Chat
+Real-time chat interface with:
+- Multi-turn conversation support
+- History management
+- User-friendly commands
+
+### Batch Processing
+Process multiple conversations:
+- Evaluate model on test sets
+- Generate responses for datasets
+- Performance benchmarking
+
+## Conversation Format
+
+The model uses special tokens to structure conversations:
+
+```
+<|user|> Hello, how are you?
+<|assistant|> I'm doing well, thank you! How can I assist you today?
+<|user|> I need help with Python
+<|assistant|> I'd be happy to help with Python! What specific topic?
+<|endoftext|>
+```
+
+## Training Tips
+
+1. **Data Quality**: Ensure conversations are natural and coherent
+2. **History Length**: Adjust `max_history` based on your use case
+3. **Temperature**: Lower (0.6-0.8) for focused responses, higher (0.8-1.0) for creative
+4. **Model Size**: Start with tiny/small, scale up as needed
+
+## Inference Parameters
+
+Adjust generation parameters in inference scripts:
+
+```python
+response = gptmed.generate(
+    checkpoint=checkpoint_path,
+    prompt=prompt,
+    tokenizer=tokenizer_path,
+    max_length=150,      # Maximum response length
+    temperature=0.8,     # Randomness (0.0-1.0)
+    top_k=50,            # Top-k sampling
+    top_p=0.9            # Nucleus sampling
+)
+```
+
+## Example Usage
+
+```python
+from utils.conversation_handler import ConversationHistory
+import gptmed
+
+# Initialize conversation
+history = ConversationHistory(max_history=5)
+
+# Add messages
+history.add_user_message("What is machine learning?")
+
+# Generate prompt
+prompt = history.get_prompt()
+
+# Generate response
+response = gptmed.generate(
+    checkpoint='models/checkpoints/best_model.pt',
+    prompt=prompt,
+    tokenizer='tokenizer/conv_tokenizer.model'
+)
+
+# Add response to history
+history.add_assistant_message(response)
+```
+""")
+
+def create_basic_project(project_name):
+    """Create basic project structure (original behavior)"""
     os.makedirs(os.path.join(project_name, "configs"))
     os.makedirs(os.path.join(project_name, "tasks"))
     os.makedirs(os.path.join(project_name, "models"))
     os.makedirs(os.path.join(project_name, "data"))
     with open(os.path.join(project_name, "main.py"), "w") as f:
         f.write("import gptmed\n\n# Your project entrypoint\n")
-
+
+def startproject(project_name, project_type=None):
+    """
+    Create a new gptmed project.
+
+    Args:
+        project_name: Name of the project
+        project_type: Type of project ('qna', 'conversational', or None for basic)
+    """
+    if not project_name.isidentifier():
+        print("Invalid project name. Your project name must be a valid Python identifier. "
+              "Do not use hyphens or spaces. Use underscores instead.")
+        sys.exit(1)
+
+    if os.path.exists(project_name):
+        print(f"Directory '{project_name}' already exists.")
+        sys.exit(1)
+
+    # Create project based on type
+    if project_type == "qna":
+        create_qna_templates(project_name)
+        print(f"QNA project '{project_name}' created successfully!")
+        print(f"\nNext steps:")
+        print(f"1. cd {project_name}")
+        print(f"2. Place your QNA data in data/raw/qna_data.jsonl")
+        print(f"3. Run: python preprocess.py")
+        print(f"4. Run: python main.py")
+
+    elif project_type == "conversational":
+        create_conversational_templates(project_name)
+        print(f"Conversational project '{project_name}' created successfully!")
+        print(f"\nNext steps:")
+        print(f"1. cd {project_name}")
+        print(f"2. Place your conversation data in data/raw/conversations.jsonl")
+        print(f"3. Run: python preprocess.py")
+        print(f"4. Run: python main.py")
+        print(f"5. For interactive chat: cd inference && python interactive_chat.py")
+
+    else:
+        create_basic_project(project_name)
+        print(f"Project '{project_name}' created.")
+
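One detail worth calling out in the templates above: `ConversationHistory.get_prompt` always appends a dangling assistant token, so generation continues in the assistant's voice. A small sketch of that behavior, run from inside a generated conversational project (the messages are hypothetical):

```python
# utils/conversation_handler.py is written into the project by
# create_conversational_templates above; the messages are hypothetical.
from utils.conversation_handler import ConversationHistory

history = ConversationHistory(max_history=5)
history.add_user_message("Hello!")
history.add_assistant_message("Hi! How can I help?")
history.add_user_message("Tell me about AI")

print(history.get_prompt())
# <|user|> Hello!
# <|assistant|> Hi! How can I help?
# <|user|> Tell me about AI
# <|assistant|>
```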
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gptmed
-Version: 0.5.3
+Version: 0.5.5
 Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
 Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>, Sanjog Sigdel <sanjog.sigdel@ku.edu.np>
 License-Expression: MIT
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/RECORD
CHANGED
@@ -10,8 +10,9 @@ gptmed/data/parsers/__init__.py,sha256=BgVzXuZgeE5DUCC4SzN7vflL40wQ4Q4_4DmJ1Y43_
 gptmed/data/parsers/medquad_parser.py,sha256=g3QCRiVBdcq8RdyuYH_qKFrHgU5KkHY59WfWxUwspP0,7974
 gptmed/data/parsers/text_formatter.py,sha256=tVmnDBT54BbxX9BPKMXSPzzLmM39frDxKRKuz_HoRag,4072
 gptmed/framework/__init__.py,sha256=TlzM7NS_n0KQnm9PQTJRrb5pEb6rBXC1pqGPhbSO_bQ,25
-gptmed/framework/cli/__init__.py,sha256=
-gptmed/framework/cli/
+gptmed/framework/cli/__init__.py,sha256=oBUmoaWLCjFs3_aod-hcMCcC11UP4t4SohDnZ7Sdmx0,729
+gptmed/framework/cli/__main__.py,sha256=rLBZjEi695ZgOW8pypqpg2kLgtcDhrI_9_QcrUO3WkU,103
+gptmed/framework/cli/startproject.py,sha256=l73Isqtp5MTMxPBfsy4sIpVp2ClA9p9_cRG7Lg7QERY,25489
 gptmed/inference/__init__.py,sha256=NDPViXhOgpItC8n13T9axX4UH1E7mrjt6kJ5OfIwvMs,25
 gptmed/inference/decoding_utils.py,sha256=zTDZYdl2jcGwSrcINXMw-5uoYuF4A9TSushhPxJi1o0,5041
 gptmed/inference/generation_config.py,sha256=hpPyZUk1K6qGSBAoQx3Jm0_ZrrYld77ACxbIlCCCcVU,2813
@@ -45,9 +46,9 @@ gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
 gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
 gptmed/utils/checkpoints.py,sha256=jPKJtO0YRZieGmpwqotgDkBzd__s_raDxS1kLpfjBJE,7113
 gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
-gptmed-0.5.
-gptmed-0.5.
-gptmed-0.5.
-gptmed-0.5.
-gptmed-0.5.
-gptmed-0.5.
+gptmed-0.5.5.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
+gptmed-0.5.5.dist-info/METADATA,sha256=-7Drfeaxy2SWQ3nlnHqwP4pj1q9bsrmTxe99pIWMePk,13842
+gptmed-0.5.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+gptmed-0.5.5.dist-info/entry_points.txt,sha256=ZZeYg2kOQuHHvRvQYRvq5L-RpClnBMHSpUom9DxQW0c,145
+gptmed-0.5.5.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
+gptmed-0.5.5.dist-info/RECORD,,
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/WHEEL
File without changes
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/entry_points.txt
File without changes
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/licenses/LICENSE
File without changes
{gptmed-0.5.3.dist-info → gptmed-0.5.5.dist-info}/top_level.txt
File without changes
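Aside from the new CLI code, the only substantive change is the version bump in METADATA, so a quick way to confirm which side of this diff is installed is a standard-library metadata query:

```python
# Confirm the installed gptmed version (importlib.metadata is stdlib).
from importlib.metadata import version

print(version("gptmed"))  # "0.5.5" after upgrading, "0.5.3" before
```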