wolfhece 2.2.43__py3-none-any.whl → 2.2.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wolfhece/ChatwWOLF.py +200 -0
- wolfhece/Model1D.py +1785 -11
- wolfhece/PyCrosssections.py +1536 -699
- wolfhece/PyDraw.py +61 -8
- wolfhece/PyVertexvectors.py +64 -22
- wolfhece/RatingCurve_xml.py +15 -1
- wolfhece/analyze_poly.py +198 -4
- wolfhece/apps/version.py +1 -1
- wolfhece/dike.py +265 -19
- wolfhece/eikonal.py +1 -0
- wolfhece/wolf_array.py +242 -29
- {wolfhece-2.2.43.dist-info → wolfhece-2.2.44.dist-info}/METADATA +2 -2
- {wolfhece-2.2.43.dist-info → wolfhece-2.2.44.dist-info}/RECORD +16 -15
- {wolfhece-2.2.43.dist-info → wolfhece-2.2.44.dist-info}/WHEEL +0 -0
- {wolfhece-2.2.43.dist-info → wolfhece-2.2.44.dist-info}/entry_points.txt +0 -0
- {wolfhece-2.2.43.dist-info → wolfhece-2.2.44.dist-info}/top_level.txt +0 -0
wolfhece/ChatwWOLF.py
ADDED
@@ -0,0 +1,200 @@
# Data preparation for the ChatWOLF model, a conversational assistant specialised in questions about WOLF.
# The data are mainly extracted from the rst files of the online help, but also from the py files of the API.

# Import the required modules
import re
import json
import pickle
import logging
from pathlib import Path
from typing import List, Tuple

import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

# Report whether CUDA is available
print(torch.cuda.is_available())

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

rst_directory = Path("D:/ProgrammationGitLab/HECEPython/docs/source")
py_directory = Path("D:/ProgrammationGitLab/HECEPython/wolfhece")
output_directory = Path("D:/ProgrammationGitLab/HECEPython/wolfhece/models/chatwolf")
output_directory.mkdir(parents=True, exist_ok=True)

# Extract and clean the text of an rst file
def extract_text_from_rst(file_path: Path) -> str:
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()
    # Clean the text
    text = re.sub(r'\.\. _.*?:', '', text)  # Remove reference targets
    # Remove Sphinx directive blocks (directive line up to the next blank line)
    directives = [
        "note", "warning", "code-block", "image", "figure", "table", "rubric",
        "sidebar", "literalinclude", "math", "raw", "toctree", "index",
        "glossary", "footnote", "citation", "epigraph", "highlight", "hlist",
        "csv-table", "list-table", "contents", "include", "admonition",
        "tip", "important", "caution", "seealso",
    ]
    for directive in directives:
        text = re.sub(rf'\.\. {directive}::.*?\n\n', '', text, flags=re.DOTALL)
    return text

def scan_files() -> List[Document]:
    # Scan all rst and py files and extract their text
    documents = []
    for rst_file in rst_directory.rglob("*.rst"):
        text = extract_text_from_rst(rst_file)
        if text.strip():  # Only add non-empty documents
            documents.append(Document(page_content=text, metadata={"source": str(rst_file)}))
            logger.info(f"Extracted text from {rst_file}")
    for py_file in py_directory.rglob("*.py"):
        with open(py_file, 'r', encoding='utf-8') as file:
            text = file.read()
        if text.strip():  # Only add non-empty documents
            documents.append(Document(page_content=text, metadata={"source": str(py_file)}))
            logger.info(f"Extracted text from {py_file}")
    logger.info(f"Total documents extracted: {len(documents)}")
    return documents

def split_and_prepare_data(documents: List[Document]) -> Tuple[List[Document], Path]:
    # Split documents into smaller overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        length_function=len
    )
    texts = text_splitter.split_documents(documents)
    logger.info(f"Total text chunks created: {len(texts)}")
    # Save the chunks to JSONL for dataset creation
    jsonl_path = output_directory / "chatwolf_data.jsonl"
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for text in texts:
            json.dump({"text": text.page_content}, f)
            f.write('\n')
    logger.info(f"Saved text chunks to {jsonl_path}")
    return texts, jsonl_path

def train_model(jsonl_path: Path):
    # Load the dataset from the JSONL file
    dataset = load_dataset('json', data_files=str(jsonl_path))['train']
    # Split the dataset into training and validation sets
    splits = dataset.train_test_split(test_size=0.1)
    train_dataset = splits['train']
    eval_dataset = splits['test']
    logger.info(f"Training dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(eval_dataset)}")
    # Define model and tokenizer
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=str(output_directory / "output"),
        eval_strategy="epoch",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        save_strategy="epoch",
        logging_dir=str(output_directory / "logs"),
        logging_steps=10,
        save_total_limit=2,
        fp16=False,  # Set to False to avoid FP16 errors on unsupported hardware
        load_best_model_at_end=True,
    )
    # Tokenize the datasets
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    eval_dataset = eval_dataset.map(tokenize_function, batched=True)
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    # Data collator for causal language modeling (labels derived from input_ids, no masking)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # Train the model
    trainer.train()
    # Save the fine-tuned model
    trainer.save_model(str(output_directory / "chatwolf_model"))
    logger.info(f"Saved fine-tuned model to {output_directory / 'chatwolf_model'}")
    return model, tokenizer

def load_model_and_tokenizer():
    # Reload the fine-tuned model and the base tokenizer
    model = AutoModelForCausalLM.from_pretrained(str(output_directory / "chatwolf_model"))
    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

documents = scan_files()
texts, jsonl_path = split_and_prepare_data(documents)

TRAIN = False  # Toggle to re-train the model instead of loading the saved one
if TRAIN:
    model, tokenizer = train_model(jsonl_path)
else:
    model, tokenizer = load_model_and_tokenizer()

# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = FAISS.from_documents(texts, embeddings)
vector_store.save_local(str(output_directory / "faiss_index"))
logger.info(f"Saved FAISS index to {output_directory / 'faiss_index'}")

# Create the retrieval QA chain
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512, temperature=0.7, top_p=0.9, repetition_penalty=1.2)
hf_llm = HuggingFacePipeline(pipeline=llm_pipeline)
qa_chain = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff", retriever=vector_store.as_retriever())

# Save the QA chain
with open(output_directory / "qa_chain.pkl", 'wb') as f:
    pickle.dump(qa_chain, f)
logger.info(f"Saved QA chain to {output_directory / 'qa_chain.pkl'}")

# Example usage of the QA chain
def answer_question(question: str) -> str:
    return qa_chain.run(question)

example_question = "How to create a new map in WOLF?"
answer = answer_question(example_question)
logger.info(f"Question: {example_question}\nAnswer: {answer}")
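
A later session can reuse the artifacts written by this script without re-running the extraction or the training. The following is a minimal sketch, not part of the released file, assuming the same output_directory, the same embedding model and a recent LangChain version; the k value is illustrative:

from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Same location the script above writes to (assumed unchanged between runs)
output_directory = Path("D:/ProgrammationGitLab/HECEPython/wolfhece/models/chatwolf")

# The index must be reloaded with the same embedding model it was built with
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Recent LangChain versions refuse to unpickle a local index unless explicitly allowed
vector_store = FAISS.load_local(
    str(output_directory / "faiss_index"),
    embeddings,
    allow_dangerous_deserialization=True,
)

# Retrieve the chunks most relevant to a question; the RetrievalQA chain in the
# script above would then pass them as context to the fine-tuned GPT-2 model
for doc in vector_store.similarity_search("How to create a new map in WOLF?", k=4):
    print(doc.metadata["source"])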