wolfhece 2.2.43__py3-none-any.whl → 2.2.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wolfhece/ChatwWOLF.py ADDED
@@ -0,0 +1,200 @@
+ # Data preparation for the ChatWOLF model, a conversational agent specialised in WOLF-related questions.
+ # The data are mainly extracted from the rst files of the online help, but also from the py files of the API.
+ import torch
+ print(torch.cuda.is_available())  # quick check that a CUDA device is visible
+
+ # Import the required modules
+ import re
+ import json
+ import logging
+ import pickle
+ from pathlib import Path
+ from typing import List, Tuple
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.docstore.document import Document
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains.retrieval_qa.base import RetrievalQA
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
+ from datasets import load_dataset
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ rst_directory = Path("D:/ProgrammationGitLab/HECEPython/docs/source")
+ py_directory = Path("D:/ProgrammationGitLab/HECEPython/wolfhece")
+ output_directory = Path("D:/ProgrammationGitLab/HECEPython/wolfhece/models/chatwolf")
+ output_directory.mkdir(parents=True, exist_ok=True)
+
+ # Extract the text from an rst file and strip the Sphinx directives
+ def extract_text_from_rst(file_path: Path) -> str:
+     with open(file_path, 'r', encoding='utf-8') as file:
+         text = file.read()
+     # Clean the text: drop reference targets, then every directive block
+     text = re.sub(r'\.\. _.*?:', '', text)  # remove reference targets
+     directives = [
+         'note', 'warning', 'code-block', 'image', 'figure', 'table', 'rubric',
+         'sidebar', 'literalinclude', 'math', 'raw', 'toctree', 'index',
+         'glossary', 'footnote', 'citation', 'epigraph', 'highlight', 'hlist',
+         'csv-table', 'list-table', 'contents', 'include', 'admonition',
+         'tip', 'important', 'caution', 'seealso',
+     ]
+     for directive in directives:
+         # A directive block runs from ".. <name>::" to the next blank line
+         text = re.sub(rf'\.\. {directive}::.*?\n\n', '', text, flags=re.DOTALL)
+
+     return text
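+ # Hedged illustration (not part of the original file): a tiny self-check of the
+ # directive-stripping technique above; plain paragraphs survive, directive blocks go.
+ def _demo_rst_cleaning() -> None:
+     sample = "Intro paragraph.\n\n.. note::\n   Hidden text.\n\nBody paragraph.\n"
+     cleaned = re.sub(r'\.\. note::.*?\n\n', '', sample, flags=re.DOTALL)
+     assert cleaned == "Intro paragraph.\n\nBody paragraph.\n"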
+
+ def scan_files() -> List[Document]:
+     # Scan all files and extract text
+     documents = []
+     for rst_file in rst_directory.rglob("*.rst"):
+         text = extract_text_from_rst(rst_file)
+         if text.strip():  # only add non-empty documents
+             documents.append(Document(page_content=text, metadata={"source": str(rst_file)}))
+             logger.info(f"Extracted text from {rst_file}")
+     for py_file in py_directory.rglob("*.py"):
+         with open(py_file, 'r', encoding='utf-8') as file:
+             text = file.read()
+         if text.strip():  # only add non-empty documents
+             documents.append(Document(page_content=text, metadata={"source": str(py_file)}))
+             logger.info(f"Extracted text from {py_file}")
+     logger.info(f"Total documents extracted: {len(documents)}")
+     return documents
+
+ def split_and_prepare_data(documents: List[Document]) -> Tuple[List[Document], Path]:
+     # Split documents into smaller chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=100,
+         length_function=len
+     )
+     texts = text_splitter.split_documents(documents)
+     logger.info(f"Total text chunks created: {len(texts)}")
+     # Save the chunks to JSONL for dataset creation
+     jsonl_path = output_directory / "chatwolf_data.jsonl"
+     with open(jsonl_path, 'w', encoding='utf-8') as f:
+         for text in texts:
+             json.dump({"text": text.page_content}, f)
+             f.write('\n')
+     logger.info(f"Saved text chunks to {jsonl_path}")
+     return texts, jsonl_path
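+ # Hedged sketch (not in the original): read one record back to check the JSONL shape;
+ # the "text" field name matches what split_and_prepare_data writes above.
+ def _peek_first_chunk(path: Path) -> str:
+     with open(path, 'r', encoding='utf-8') as f:
+         return json.loads(f.readline())["text"]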
+
+ def train_model(jsonl_path: Path):
+     # Load the dataset written by split_and_prepare_data
+     dataset = load_dataset('json', data_files=str(jsonl_path))['train']
+     # Split the dataset into training and validation sets
+     split = dataset.train_test_split(test_size=0.1)
+     train_dataset = split['train']
+     eval_dataset = split['test']
+     logger.info(f"Training dataset size: {len(train_dataset)}")
+     logger.info(f"Validation dataset size: {len(eval_dataset)}")
+     # Define model and tokenizer
+     model_name = "gpt2"
+     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     # Define training arguments
+     training_args = TrainingArguments(
+         output_dir=str(output_directory / "output"),
+         eval_strategy="epoch",
+         num_train_epochs=3,
+         per_device_train_batch_size=1,
+         per_device_eval_batch_size=1,
+         save_strategy="epoch",
+         logging_dir=str(output_directory / "logs"),
+         logging_steps=10,
+         save_total_limit=2,
+         fp16=False,  # set to False to avoid FP16 errors on unsupported hardware
+         load_best_model_at_end=True,
+     )
+     # Tokenize the datasets
+     def tokenize_function(examples):
+         return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
+     train_dataset = train_dataset.map(tokenize_function, batched=True)
+     eval_dataset = eval_dataset.map(tokenize_function, batched=True)
+     train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
+     eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
+     # Data collator for causal language modeling (labels are shifted copies of input_ids)
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False,
+     )
+     # Initialize the Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+     )
+     # Train the model
+     trainer.train()
+     # Save the fine-tuned model
+     trainer.save_model(str(output_directory / "chatwolf_model"))
+     logger.info(f"Saved fine-tuned model to {output_directory / 'chatwolf_model'}")
+     return model, tokenizer
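+ # Hedged sketch (not in the original): direct generation with the fine-tuned model,
+ # independent of the retrieval chain assembled below.
+ def _sample_completion(prompt: str, model, tokenizer, max_new_tokens: int = 40) -> str:
+     inputs = tokenizer(prompt, return_tensors="pt")
+     output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens,
+                                 do_sample=True, top_p=0.9,
+                                 pad_token_id=tokenizer.eos_token_id)
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)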
164
+
165
+ def load_model_and_tokenizer():
166
+ model = AutoModelForCausalLM.from_pretrained(output_directory / "chatwolf_model")
167
+ tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
168
+ if tokenizer.pad_token is None:
169
+ tokenizer.pad_token = tokenizer.eos_token
170
+ return model, tokenizer
+
+ documents = scan_files()
+ texts, jsonl_path = split_and_prepare_data(documents)
+
+ RETRAIN = False  # flip to True to fine-tune again instead of loading the saved model
+ if RETRAIN:
+     model, tokenizer = train_model(jsonl_path)
+ else:
+     model, tokenizer = load_model_and_tokenizer()
+
+
+ # Create embeddings and vector store
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ vector_store = FAISS.from_documents(texts, embeddings)
+ vector_store.save_local(str(output_directory / "faiss_index"))
+ logger.info(f"Saved FAISS index to {output_directory / 'faiss_index'}")
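+ # Hedged illustration (not in the original): a raw similarity query against the index,
+ # useful for checking which source files the retriever would surface.
+ def _top_sources(query: str, k: int = 3) -> List[str]:
+     return [doc.metadata["source"] for doc in vector_store.similarity_search(query, k=k)]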
+ # Create the retrieval QA chain
+ llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer,
+                         max_new_tokens=512, do_sample=True, temperature=0.7,
+                         top_p=0.9, repetition_penalty=1.2)
+ hf_llm = HuggingFacePipeline(pipeline=llm_pipeline)
+ qa_chain = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff", retriever=vector_store.as_retriever())
+ # Save the QA chain (fragile: this pickles the whole model; the FAISS index saved
+ # above is the more portable artefact)
+ with open(output_directory / "qa_chain.pkl", 'wb') as f:
+     pickle.dump(qa_chain, f)
+ logger.info(f"Saved QA chain to {output_directory / 'qa_chain.pkl'}")
+ # Example usage of the QA chain
+ def answer_question(question: str) -> str:
+     return qa_chain.run(question)
+ example_question = "How to create a new map in WOLF?"
+ answer = answer_question(example_question)
+ logger.info(f"Question: {example_question}\nAnswer: {answer}")
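+ # Hedged sketch (not in the original): rebuilding the chain in a fresh session from the
+ # saved FAISS index instead of unpickling it; allow_dangerous_deserialization is required
+ # by recent langchain_community releases because the index metadata is pickled on disk.
+ def reload_qa_chain() -> RetrievalQA:
+     store = FAISS.load_local(str(output_directory / "faiss_index"), embeddings,
+                              allow_dangerous_deserialization=True)
+     return RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",
+                                        retriever=store.as_retriever())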