npcpy 1.0.26__py3-none-any.whl → 1.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- npcpy/__init__.py +0 -7
- npcpy/data/audio.py +16 -99
- npcpy/data/image.py +43 -42
- npcpy/data/load.py +83 -124
- npcpy/data/text.py +28 -28
- npcpy/data/video.py +8 -32
- npcpy/data/web.py +51 -23
- npcpy/ft/diff.py +110 -0
- npcpy/ft/ge.py +115 -0
- npcpy/ft/memory_trainer.py +171 -0
- npcpy/ft/model_ensembler.py +357 -0
- npcpy/ft/rl.py +360 -0
- npcpy/ft/sft.py +248 -0
- npcpy/ft/usft.py +128 -0
- npcpy/gen/audio_gen.py +24 -0
- npcpy/gen/embeddings.py +13 -13
- npcpy/gen/image_gen.py +262 -117
- npcpy/gen/response.py +615 -415
- npcpy/gen/video_gen.py +53 -7
- npcpy/llm_funcs.py +1869 -437
- npcpy/main.py +1 -1
- npcpy/memory/command_history.py +844 -510
- npcpy/memory/kg_vis.py +833 -0
- npcpy/memory/knowledge_graph.py +892 -1845
- npcpy/memory/memory_processor.py +81 -0
- npcpy/memory/search.py +188 -90
- npcpy/mix/debate.py +192 -3
- npcpy/npc_compiler.py +1672 -801
- npcpy/npc_sysenv.py +593 -1266
- npcpy/serve.py +3120 -0
- npcpy/sql/ai_function_tools.py +257 -0
- npcpy/sql/database_ai_adapters.py +186 -0
- npcpy/sql/database_ai_functions.py +163 -0
- npcpy/sql/model_runner.py +19 -19
- npcpy/sql/npcsql.py +706 -507
- npcpy/sql/sql_model_compiler.py +156 -0
- npcpy/tools.py +183 -0
- npcpy/work/plan.py +13 -279
- npcpy/work/trigger.py +3 -3
- npcpy-1.2.32.dist-info/METADATA +803 -0
- npcpy-1.2.32.dist-info/RECORD +54 -0
- npcpy/data/dataframes.py +0 -171
- npcpy/memory/deep_research.py +0 -125
- npcpy/memory/sleep.py +0 -557
- npcpy/modes/_state.py +0 -78
- npcpy/modes/alicanto.py +0 -1075
- npcpy/modes/guac.py +0 -785
- npcpy/modes/mcp_npcsh.py +0 -822
- npcpy/modes/npc.py +0 -213
- npcpy/modes/npcsh.py +0 -1158
- npcpy/modes/plonk.py +0 -409
- npcpy/modes/pti.py +0 -234
- npcpy/modes/serve.py +0 -1637
- npcpy/modes/spool.py +0 -312
- npcpy/modes/wander.py +0 -549
- npcpy/modes/yap.py +0 -572
- npcpy/npc_team/alicanto.npc +0 -2
- npcpy/npc_team/alicanto.png +0 -0
- npcpy/npc_team/assembly_lines/test_pipeline.py +0 -181
- npcpy/npc_team/corca.npc +0 -13
- npcpy/npc_team/foreman.npc +0 -7
- npcpy/npc_team/frederic.npc +0 -6
- npcpy/npc_team/frederic4.png +0 -0
- npcpy/npc_team/guac.png +0 -0
- npcpy/npc_team/jinxs/automator.jinx +0 -18
- npcpy/npc_team/jinxs/bash_executer.jinx +0 -31
- npcpy/npc_team/jinxs/calculator.jinx +0 -11
- npcpy/npc_team/jinxs/edit_file.jinx +0 -96
- npcpy/npc_team/jinxs/file_chat.jinx +0 -14
- npcpy/npc_team/jinxs/gui_controller.jinx +0 -28
- npcpy/npc_team/jinxs/image_generation.jinx +0 -29
- npcpy/npc_team/jinxs/internet_search.jinx +0 -30
- npcpy/npc_team/jinxs/local_search.jinx +0 -152
- npcpy/npc_team/jinxs/npcsh_executor.jinx +0 -31
- npcpy/npc_team/jinxs/python_executor.jinx +0 -8
- npcpy/npc_team/jinxs/screen_cap.jinx +0 -25
- npcpy/npc_team/jinxs/sql_executor.jinx +0 -33
- npcpy/npc_team/kadiefa.npc +0 -3
- npcpy/npc_team/kadiefa.png +0 -0
- npcpy/npc_team/npcsh.ctx +0 -9
- npcpy/npc_team/npcsh_sibiji.png +0 -0
- npcpy/npc_team/plonk.npc +0 -2
- npcpy/npc_team/plonk.png +0 -0
- npcpy/npc_team/plonkjr.npc +0 -2
- npcpy/npc_team/plonkjr.png +0 -0
- npcpy/npc_team/sibiji.npc +0 -5
- npcpy/npc_team/sibiji.png +0 -0
- npcpy/npc_team/spool.png +0 -0
- npcpy/npc_team/templates/analytics/celona.npc +0 -0
- npcpy/npc_team/templates/hr_support/raone.npc +0 -0
- npcpy/npc_team/templates/humanities/eriane.npc +0 -4
- npcpy/npc_team/templates/it_support/lineru.npc +0 -0
- npcpy/npc_team/templates/marketing/slean.npc +0 -4
- npcpy/npc_team/templates/philosophy/maurawa.npc +0 -0
- npcpy/npc_team/templates/sales/turnic.npc +0 -4
- npcpy/npc_team/templates/software/welxor.npc +0 -0
- npcpy/npc_team/yap.png +0 -0
- npcpy/routes.py +0 -958
- npcpy/work/mcp_helpers.py +0 -357
- npcpy/work/mcp_server.py +0 -194
- npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/automator.jinx +0 -18
- npcpy-1.0.26.data/data/npcpy/npc_team/bash_executer.jinx +0 -31
- npcpy-1.0.26.data/data/npcpy/npc_team/calculator.jinx +0 -11
- npcpy-1.0.26.data/data/npcpy/npc_team/celona.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/corca.npc +0 -13
- npcpy-1.0.26.data/data/npcpy/npc_team/edit_file.jinx +0 -96
- npcpy-1.0.26.data/data/npcpy/npc_team/eriane.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/file_chat.jinx +0 -14
- npcpy-1.0.26.data/data/npcpy/npc_team/foreman.npc +0 -7
- npcpy-1.0.26.data/data/npcpy/npc_team/frederic.npc +0 -6
- npcpy-1.0.26.data/data/npcpy/npc_team/frederic4.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/guac.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/gui_controller.jinx +0 -28
- npcpy-1.0.26.data/data/npcpy/npc_team/image_generation.jinx +0 -29
- npcpy-1.0.26.data/data/npcpy/npc_team/internet_search.jinx +0 -30
- npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.npc +0 -3
- npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/lineru.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/local_search.jinx +0 -152
- npcpy-1.0.26.data/data/npcpy/npc_team/maurawa.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh.ctx +0 -9
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_executor.jinx +0 -31
- npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_sibiji.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/plonk.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/plonk.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.npc +0 -2
- npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/python_executor.jinx +0 -8
- npcpy-1.0.26.data/data/npcpy/npc_team/raone.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/screen_cap.jinx +0 -25
- npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.npc +0 -5
- npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/slean.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/spool.png +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/sql_executor.jinx +0 -33
- npcpy-1.0.26.data/data/npcpy/npc_team/test_pipeline.py +0 -181
- npcpy-1.0.26.data/data/npcpy/npc_team/turnic.npc +0 -4
- npcpy-1.0.26.data/data/npcpy/npc_team/welxor.npc +0 -0
- npcpy-1.0.26.data/data/npcpy/npc_team/yap.png +0 -0
- npcpy-1.0.26.dist-info/METADATA +0 -827
- npcpy-1.0.26.dist-info/RECORD +0 -139
- npcpy-1.0.26.dist-info/entry_points.txt +0 -11
- /npcpy/{modes → ft}/__init__.py +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/data/text.py
CHANGED
|
@@ -37,70 +37,70 @@ def rag_search(
|
|
|
37
37
|
)
|
|
38
38
|
results = []
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
|
|
41
41
|
query_embedding = embedding_model.encode(
|
|
42
42
|
query, convert_to_tensor=True, show_progress_bar=False
|
|
43
43
|
)
|
|
44
44
|
if isinstance(text_data, str):
|
|
45
|
-
|
|
45
|
+
|
|
46
46
|
lines = text_data.split(".")
|
|
47
47
|
if not lines:
|
|
48
48
|
return results
|
|
49
|
-
|
|
49
|
+
|
|
50
50
|
if text_data_embedded is None:
|
|
51
51
|
line_embeddings = embedding_model.encode(lines, convert_to_tensor=True)
|
|
52
52
|
else:
|
|
53
53
|
line_embeddings = text_data_embedded
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
cosine_scores = util.cos_sim(query_embedding, line_embeddings)[0].cpu().numpy()
|
|
56
56
|
|
|
57
|
-
|
|
57
|
+
|
|
58
58
|
relevant_line_indices = np.where(cosine_scores >= similarity_threshold)[0]
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
|
|
62
62
|
|
|
63
63
|
for idx in relevant_line_indices:
|
|
64
64
|
idx = int(idx)
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
start_idx = max(0, idx - 10)
|
|
67
|
-
end_idx = min(len(lines), idx + 11)
|
|
67
|
+
end_idx = min(len(lines), idx + 11)
|
|
68
68
|
snippet = ". ".join(lines[start_idx:end_idx])
|
|
69
69
|
results.append(snippet)
|
|
70
70
|
|
|
71
71
|
elif isinstance(text_data, dict):
|
|
72
72
|
for filename, content in text_data.items():
|
|
73
|
-
|
|
73
|
+
|
|
74
74
|
lines = content.split("\n")
|
|
75
75
|
if not lines:
|
|
76
76
|
continue
|
|
77
|
-
|
|
77
|
+
|
|
78
78
|
if text_data_embedded is None:
|
|
79
79
|
line_embeddings = embedding_model.encode(lines, convert_to_tensor=True)
|
|
80
80
|
else:
|
|
81
81
|
line_embeddings = text_data_embedded[filename]
|
|
82
|
-
|
|
82
|
+
|
|
83
83
|
cosine_scores = (
|
|
84
84
|
util.cos_sim(query_embedding, line_embeddings)[0].cpu().numpy()
|
|
85
85
|
)
|
|
86
86
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
|
|
90
90
|
relevant_line_indices = np.where(cosine_scores >= similarity_threshold)[0]
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
|
|
94
94
|
for idx in relevant_line_indices:
|
|
95
|
-
idx = int(idx)
|
|
96
|
-
|
|
95
|
+
idx = int(idx)
|
|
96
|
+
|
|
97
97
|
start_idx = max(0, idx - 10)
|
|
98
98
|
end_idx = min(
|
|
99
99
|
len(lines), idx + 11
|
|
100
|
-
)
|
|
100
|
+
)
|
|
101
101
|
snippet = "\n".join(lines[start_idx:end_idx])
|
|
102
102
|
results.append((filename, snippet))
|
|
103
|
-
|
|
103
|
+
|
|
104
104
|
return results
|
|
105
105
|
|
|
106
106
|
|
|
@@ -122,10 +122,10 @@ def load_all_files(
|
|
|
122
122
|
"""
|
|
123
123
|
text_data = {}
|
|
124
124
|
if depth < 1:
|
|
125
|
-
return text_data
|
|
125
|
+
return text_data
|
|
126
126
|
|
|
127
127
|
if extensions is None:
|
|
128
|
-
|
|
128
|
+
|
|
129
129
|
extensions = [
|
|
130
130
|
".txt",
|
|
131
131
|
".md",
|
|
@@ -139,11 +139,11 @@ def load_all_files(
|
|
|
139
139
|
".ts",
|
|
140
140
|
".tsx",
|
|
141
141
|
".npc",
|
|
142
|
-
|
|
142
|
+
|
|
143
143
|
]
|
|
144
144
|
|
|
145
145
|
try:
|
|
146
|
-
|
|
146
|
+
|
|
147
147
|
entries = os.listdir(directory)
|
|
148
148
|
except Exception as e:
|
|
149
149
|
print(f"Could not list directory {directory}: {e}")
|
|
@@ -159,7 +159,7 @@ def load_all_files(
|
|
|
159
159
|
except Exception as e:
|
|
160
160
|
print(f"Could not read file {path}: {e}")
|
|
161
161
|
elif os.path.isdir(path):
|
|
162
|
-
|
|
162
|
+
|
|
163
163
|
subdir_data = load_all_files(path, extensions, depth=depth - 1)
|
|
164
164
|
text_data.update(subdir_data)
|
|
165
165
|
|
npcpy/data/video.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
|
|
1
|
+
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def process_video(file_path, table_name):
|
|
5
|
-
|
|
5
|
+
|
|
6
|
+
import cv2
|
|
7
|
+
import base64
|
|
6
8
|
|
|
7
9
|
embeddings = []
|
|
8
10
|
texts = []
|
|
@@ -16,37 +18,11 @@ def process_video(file_path, table_name):
|
|
|
16
18
|
if not ret:
|
|
17
19
|
break
|
|
18
20
|
|
|
19
|
-
|
|
20
|
-
n = 10
|
|
21
|
-
|
|
22
|
-
# Image Embeddings
|
|
23
|
-
_, buffer = cv2.imencode(".jpg", frame) # Encode frame as JPG
|
|
24
|
-
base64_image = base64.b64encode(buffer).decode("utf-8")
|
|
25
|
-
image_info = {
|
|
26
|
-
"filename": f"frame_{i}.jpg",
|
|
27
|
-
"file_path": f"data:image/jpeg;base64,{base64_image}",
|
|
28
|
-
} # Use data URL for OpenAI
|
|
29
|
-
image_embedding_response = get_llm_response(
|
|
30
|
-
"Describe this image.",
|
|
31
|
-
image=image_info,
|
|
32
|
-
model="gpt-4",
|
|
33
|
-
provider="openai",
|
|
34
|
-
) # Replace with your image embedding model
|
|
35
|
-
if (
|
|
36
|
-
isinstance(image_embedding_response, dict)
|
|
37
|
-
and "error" in image_embedding_response
|
|
38
|
-
):
|
|
39
|
-
print(
|
|
40
|
-
f"Error generating image embedding: {image_embedding_response['error']}"
|
|
41
|
-
)
|
|
42
|
-
else:
|
|
43
|
-
# Assuming your image embedding model returns a textual description
|
|
44
|
-
embeddings.append(image_embedding_response)
|
|
45
|
-
texts.append(f"Frame {i}: {image_embedding_response}")
|
|
46
|
-
|
|
47
|
-
video.release()
|
|
21
|
+
|
|
22
|
+
n = 10
|
|
23
|
+
|
|
48
24
|
return embeddings, texts
|
|
49
25
|
|
|
50
26
|
except Exception as e:
|
|
51
27
|
print(f"Error processing video: {e}")
|
|
52
|
-
return [], []
|
|
28
|
+
return [], []
|
npcpy/data/web.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
|
|
1
|
+
|
|
2
2
|
|
|
3
3
|
import requests
|
|
4
4
|
import os
|
|
@@ -21,6 +21,25 @@ except:
|
|
|
21
21
|
pass
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def search_exa(query: str,
               api_key: str = None,
               top_k=5,
               **kwargs):
    """Search the web with Exa and return up to ``top_k`` results with page text.

    Args:
        query: Search query string.
        api_key: Exa API key. Falls back to the ``EXA_API_KEY`` environment
            variable when omitted.
        top_k: Maximum number of results to return. ``None`` returns whatever
            the service sends back.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The ``results`` list from ``Exa.search_and_contents`` truncated to
        ``top_k`` entries.

    Raises:
        ValueError: If no API key is given and ``EXA_API_KEY`` is unset.
    """
    from exa_py import Exa

    if api_key is None:
        api_key = os.environ.get('EXA_API_KEY')
    if not api_key:
        # Fail with an actionable message instead of handing None to the SDK.
        raise ValueError(
            "Exa API key missing: pass api_key or set EXA_API_KEY."
        )
    exa = Exa(api_key)

    search_args = {"text": True}
    if top_k is not None:
        # Ask the service for only as many results as we return, rather than
        # fetching its default count and discarding the extras locally.
        search_args["num_results"] = top_k
    results = exa.search_and_contents(query, **search_args)
    # Keep the slice as a safeguard in case the service returns extra items.
    return results.results[:top_k]
|
|
41
|
+
|
|
42
|
+
|
|
24
43
|
def search_perplexity(
|
|
25
44
|
query: str,
|
|
26
45
|
api_key: str = None,
|
|
@@ -30,8 +49,11 @@ def search_perplexity(
|
|
|
30
49
|
top_p: float = 0.9,
|
|
31
50
|
):
|
|
32
51
|
if api_key is None:
|
|
33
|
-
api_key = os.environ
|
|
34
|
-
|
|
52
|
+
api_key = os.environ.get("PERPLEXITY_API_KEY")
|
|
53
|
+
if api_key is None:
|
|
54
|
+
raise
|
|
55
|
+
|
|
56
|
+
|
|
35
57
|
url = "https://api.perplexity.ai/chat/completions"
|
|
36
58
|
payload = {
|
|
37
59
|
"model": "sonar",
|
|
@@ -52,13 +74,17 @@ def search_perplexity(
|
|
|
52
74
|
"response_format": None,
|
|
53
75
|
}
|
|
54
76
|
|
|
55
|
-
|
|
56
|
-
headers = {"Authorization": f"Bearer {api_key}",
|
|
77
|
+
|
|
78
|
+
headers = {"Authorization": f"Bearer {api_key}",
|
|
79
|
+
"Content-Type": "application/json"}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
response = requests.post(url,
|
|
83
|
+
json=payload,
|
|
84
|
+
headers=headers)
|
|
85
|
+
|
|
86
|
+
response = response.json()
|
|
57
87
|
|
|
58
|
-
# Make the POST request to the API
|
|
59
|
-
response = requests.post(url, json=payload, headers=headers)
|
|
60
|
-
response = json.loads(response.text)
|
|
61
|
-
#print(response)
|
|
62
88
|
return [response["choices"][0]["message"]["content"], response["citations"]]
|
|
63
89
|
|
|
64
90
|
|
|
@@ -68,7 +94,7 @@ def search_web(
|
|
|
68
94
|
provider: str=None,
|
|
69
95
|
api_key=None,
|
|
70
96
|
perplexity_kwargs: Optional[Dict[str, Any]] = None,
|
|
71
|
-
) -> List
|
|
97
|
+
) -> List:
|
|
72
98
|
"""
|
|
73
99
|
Function Description:
|
|
74
100
|
This function searches the web for information based on a query.
|
|
@@ -76,7 +102,7 @@ def search_web(
|
|
|
76
102
|
query: The search query.
|
|
77
103
|
Keyword Args:
|
|
78
104
|
num_results: The number of search results to retrieve.
|
|
79
|
-
provider: The search engine provider to use ('
|
|
105
|
+
provider: The search engine provider to use ('perplexity' or 'duckduckgo').
|
|
80
106
|
Returns:
|
|
81
107
|
A list of dictionaries with 'title', 'link', and 'content' keys.
|
|
82
108
|
"""
|
|
@@ -88,7 +114,7 @@ def search_web(
|
|
|
88
114
|
|
|
89
115
|
if provider == "perplexity":
|
|
90
116
|
search_result = search_perplexity(query, api_key=api_key, **perplexity_kwargs)
|
|
91
|
-
|
|
117
|
+
|
|
92
118
|
return search_result
|
|
93
119
|
|
|
94
120
|
if provider == "duckduckgo":
|
|
@@ -108,30 +134,32 @@ def search_web(
|
|
|
108
134
|
print("DuckDuckGo search failed: ", e)
|
|
109
135
|
urls = []
|
|
110
136
|
results = []
|
|
137
|
+
elif provider =='exa':
|
|
138
|
+
return search_exa(query, api_key=api_key, )
|
|
111
139
|
|
|
112
|
-
elif provider =='google':
|
|
140
|
+
elif provider =='google':
|
|
113
141
|
urls = list(search(query, num_results=num_results))
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
|
|
117
145
|
for url in urls:
|
|
118
146
|
try:
|
|
119
|
-
|
|
147
|
+
|
|
120
148
|
headers = {
|
|
121
149
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
|
122
150
|
}
|
|
123
151
|
response = requests.get(url, headers=headers, timeout=5)
|
|
124
152
|
response.raise_for_status()
|
|
125
153
|
|
|
126
|
-
|
|
154
|
+
|
|
127
155
|
soup = BeautifulSoup(response.text, "html.parser")
|
|
128
156
|
|
|
129
|
-
|
|
157
|
+
|
|
130
158
|
title = soup.title.string if soup.title else url
|
|
131
159
|
|
|
132
|
-
|
|
160
|
+
|
|
133
161
|
content = " ".join([p.get_text() for p in soup.find_all("p")])
|
|
134
|
-
content = " ".join(content.split())
|
|
162
|
+
content = " ".join(content.split())
|
|
135
163
|
|
|
136
164
|
results.append(
|
|
137
165
|
{
|
|
@@ -147,8 +175,8 @@ def search_web(
|
|
|
147
175
|
print(f"Error fetching {url}: {str(e)}")
|
|
148
176
|
continue
|
|
149
177
|
|
|
150
|
-
|
|
151
|
-
|
|
178
|
+
|
|
179
|
+
|
|
152
180
|
content_str = "\n".join(
|
|
153
181
|
[r["content"] + "\n Citation: " + r["link"] + "\n\n\n" for r in results]
|
|
154
182
|
)
|
npcpy/ft/diff.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# finetuning diffuser models
|
|
2
|
+
try:
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch.utils.data import DataLoader, Dataset as TorchDataset
|
|
7
|
+
from transformers import CLIPTextModel, CLIPTokenizer
|
|
8
|
+
except:
|
|
9
|
+
torch = None
|
|
10
|
+
nn = None
|
|
11
|
+
F = None
|
|
12
|
+
DataLoader = None
|
|
13
|
+
TorchDataset = None
|
|
14
|
+
CLIPTextModel = None
|
|
15
|
+
CLIPTokenizer = None
|
|
16
|
+
import math
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from typing import List, Optional, Callable
|
|
19
|
+
import numpy as np
|
|
20
|
+
from PIL import Image
|
|
21
|
+
import os
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
import gc
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class DiffusionConfig:
    """Hyper-parameters for training a small diffusion model.

    NOTE(review): how each field is consumed is not visible in this module
    excerpt. The notes below are inferred from field names and from the
    matching SimpleUNet constructor arguments — confirm against the trainer.
    """

    # Side length of training images in pixels (presumably square) — confirm.
    image_size: int = 128
    # Base feature width of the network. It has the same default as SimpleUNet's
    # ``channels`` argument (presumably forwarded to it — confirm).
    channels: int = 256
    # Width of the timestep embedding (SimpleUNet ``time_emb_dim``).
    time_emb_dim: int = 128
    # Number of diffusion steps in the noise schedule (presumably).
    timesteps: int = 1000
    # First and last beta values of the noise schedule. The schedule shape
    # (e.g. linear) is not visible here — TODO confirm.
    beta_start: float = 1e-4
    beta_end: float = 0.02
    # Training-loop settings.
    num_epochs: int = 100
    batch_size: int = 4
    learning_rate: float = 1e-5
    # How often a checkpoint is written. The unit (steps vs. epochs) is not
    # visible here — confirm.
    checkpoint_frequency: int = 1000
    # Output location for the trained model / checkpoints (presumably).
    output_dir: str = "diffusion_model"
    # Whether CLIP text conditioning is enabled. This module imports
    # CLIPTextModel/CLIPTokenizer.
    use_clip: bool = True
    # Number of image channels fed to the network (SimpleUNet
    # ``num_channels``). 1 presumably means single-channel/grayscale input.
    num_channels: int = 1
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode scalar timesteps as fixed sine/cosine feature vectors.

    ``forward`` takes a 1-D tensor of timesteps (it is indexed as
    ``time[:, None]``) and returns a ``(batch, dim)`` tensor. The first
    ``dim // 2`` columns hold ``sin(t * f_k)`` and the next ``dim // 2``
    columns hold ``cos(t * f_k)``. The frequencies ``f_k`` are geometrically
    spaced from 1 down to 1/10000; when ``dim < 4`` there is a single
    frequency of 1. When ``dim`` is odd, one zero column is appended so the
    output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # Log-spacing between successive frequencies. Clamp the denominator so
        # dim in {2, 3} (half_dim == 1) does not divide by zero.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=time.device) * -log_step
        )
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) columns; pad one zero column so
            # downstream layers sized for ``dim`` inputs still line up.
            embeddings = torch.cat(
                (embeddings, embeddings.new_zeros(embeddings.shape[0], 1)),
                dim=-1,
            )
        return embeddings
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SimpleUNet(nn.Module):
|
|
65
|
+
|
|
66
|
+
def __init__(
|
|
67
|
+
self,
|
|
68
|
+
image_size=128,
|
|
69
|
+
channels=256,
|
|
70
|
+
time_emb_dim=128,
|
|
71
|
+
num_channels=1
|
|
72
|
+
):
|
|
73
|
+
super().__init__()
|
|
74
|
+
|
|
75
|
+
self.image_size = image_size
|
|
76
|
+
|
|
77
|
+
self.time_mlp = nn.Sequential(
|
|
78
|
+
SinusoidalPositionEmbeddings(time_emb_dim),
|
|
79
|
+
nn.Linear(time_emb_dim, time_emb_dim * 4),
|
|
80
|
+
nn.GELU(),
|
|
81
|
+
nn.Linear(time_emb_dim * 4, channels),
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
self.text_mlp = nn.Sequential(
|
|
85
|
+
nn.Linear(768, time_emb_dim),
|
|
86
|
+
nn.GELU(),
|
|
87
|
+
nn.Linear(time_emb_dim, time_emb_dim),
|
|
88
|
+
nn.GELU(),
|
|
89
|
+
nn.Linear(time_emb_dim, channels),
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
self.conv_in = nn.Conv2d(num_channels, channels, 1, padding=0)
|
|
93
|
+
|
|
94
|
+
self.down1 = nn.Sequential(
|
|
95
|
+
nn.Conv2d(channels, channels * 2, 4, 2, 1),
|
|
96
|
+
nn.GroupNorm(8, channels * 2),
|
|
97
|
+
nn.GELU(),
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
self.down2 = nn.Sequential(
|
|
101
|
+
nn.Conv2d(channels * 2, channels * 4, 4, 2, 1),
|
|
102
|
+
nn.GroupNorm(8, channels * 4),
|
|
103
|
+
nn.GELU(),
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
self.down3 = nn.Sequential(
|
|
107
|
+
nn.Conv2d(channels * 4, channels * 8, 4, 2, 1),
|
|
108
|
+
nn.GroupNorm(8, channels * 8),
|
|
109
|
+
nn.GELU(),
|
|
110
|
+
)
|
npcpy/ft/ge.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Callable, Optional, List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class GAConfig:
    """Hyper-parameters for ``GeneticEvolver``."""

    # Number of individuals kept in each generation.
    population_size: int = 20
    # Probability that a selected child is passed through mutate_fn.
    mutation_rate: float = 0.15
    # Probability that two selected parents are combined with crossover_fn;
    # otherwise the first parent is carried forward as the child.
    crossover_rate: float = 0.7
    # Individuals sampled per tournament (capped at the population size).
    tournament_size: int = 3
    # Top-scoring individuals copied unchanged into the next generation.
    elitism_count: int = 2
    # Generations run() performs when no explicit count is given.
    generations: int = 50


class GeneticEvolver:
    """
    Generic GA that takes fitness, mutation, crossover
    and initialization functions to evolve any population.

    Args:
        fitness_fn: ``individual -> comparable score``. Higher is better.
        mutate_fn: ``individual -> individual``. NOTE: when crossover is
            skipped, the child passed here is the *same object* as a parent
            (possibly an elite already placed in the next generation). It
            should therefore return a new individual rather than mutating its
            argument in place.
        crossover_fn: ``(parent1, parent2) -> child``.
        initialize_fn: ``() -> individual``, used to seed the population.
        config: Hyper-parameters. Defaults to ``GAConfig()``.
    """
    def __init__(
        self,
        fitness_fn: Callable,
        mutate_fn: Callable,
        crossover_fn: Callable,
        initialize_fn: Callable,
        config: Optional[GAConfig] = None
    ):
        self.fitness_fn = fitness_fn
        self.mutate_fn = mutate_fn
        self.crossover_fn = crossover_fn
        self.initialize_fn = initialize_fn
        self.config = config or GAConfig()
        self.population = []
        # Per-generation stats dicts appended by run(); kept across calls.
        self.history = []

    def initialize_population(self):
        """Replace the population with ``population_size`` fresh individuals."""
        self.population = [
            self.initialize_fn()
            for _ in range(self.config.population_size)
        ]

    def evaluate_population(self) -> List[float]:
        """Return the fitness of every individual, in population order."""
        return [
            self.fitness_fn(individual)
            for individual in self.population
        ]

    def tournament_select(self, fitness_scores: List[float]):
        """Pick the fittest of a random sample of the population.

        The sample size is ``tournament_size``, clamped to ``[1, population
        size]`` so small populations do not make ``random.sample`` raise.
        Ties go to the first contender in sample order.
        """
        k = max(1, min(self.config.tournament_size, len(self.population)))
        contenders = random.sample(range(len(self.population)), k)
        winner_idx = max(contenders, key=lambda i: fitness_scores[i])
        return self.population[winner_idx]

    def evolve_generation(self):
        """Score the current population and replace it with the next one.

        Returns:
            A dict with ``best_fitness``, ``avg_fitness`` and
            ``best_individual``, all describing the population *before*
            this generation's replacement.

        Raises:
            ValueError: If the population is empty.
        """
        if not self.population:
            raise ValueError("cannot evolve an empty population")

        fitness_scores = self.evaluate_population()

        sorted_pop = sorted(
            zip(self.population, fitness_scores),
            key=lambda x: x[1],
            reverse=True
        )

        # Elitism: carry the top individuals over unchanged.
        new_population = [
            ind for ind, _ in sorted_pop[:self.config.elitism_count]
        ]

        while len(new_population) < self.config.population_size:
            parent1 = self.tournament_select(fitness_scores)
            parent2 = self.tournament_select(fitness_scores)

            if random.random() < self.config.crossover_rate:
                child = self.crossover_fn(parent1, parent2)
            else:
                # Same object as parent1 — see the mutate_fn note on the class.
                child = parent1

            if random.random() < self.config.mutation_rate:
                child = self.mutate_fn(child)

            new_population.append(child)

        # elitism_count may exceed population_size; trim to size.
        self.population = new_population[:self.config.population_size]

        best_fitness = max(fitness_scores)
        avg_fitness = sum(fitness_scores) / len(fitness_scores)

        return {
            'best_fitness': best_fitness,
            'avg_fitness': avg_fitness,
            'best_individual': sorted_pop[0][0]
        }

    def run(self, generations: Optional[int] = None):
        """Evolve for ``generations`` generations and return the best individual.

        Args:
            generations: Number of generations. ``None`` uses
                ``config.generations``; ``0`` runs no generations.

        Returns:
            The highest-fitness individual scored during this call. If no
            generation ran, the fittest member of the current population is
            returned instead.
        """
        if not self.population:
            self.initialize_population()

        # Only None selects the default, so an explicit 0 is honoured.
        gens = self.config.generations if generations is None else generations

        best_stats = None
        for gen in range(gens):
            gen_stats = self.evolve_generation()
            self.history.append(gen_stats)
            # Track the best seen so far: without elitism the population's best
            # can regress, so the last generation is not necessarily the best.
            if best_stats is None or gen_stats['best_fitness'] > best_stats['best_fitness']:
                best_stats = gen_stats

            if gen % 10 == 0:
                print(
                    f"Gen {gen}: "
                    f"Best={gen_stats['best_fitness']:.3f}, "
                    f"Avg={gen_stats['avg_fitness']:.3f}"
                )

        if best_stats is None:
            if not self.population:
                raise ValueError("cannot select from an empty population")
            fitness_scores = self.evaluate_population()
            best_idx = max(range(len(fitness_scores)), key=fitness_scores.__getitem__)
            return self.population[best_idx]

        return best_stats['best_individual']
|