npcpy 1.0.26__py3-none-any.whl → 1.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. npcpy/__init__.py +0 -7
  2. npcpy/data/audio.py +16 -99
  3. npcpy/data/image.py +43 -42
  4. npcpy/data/load.py +83 -124
  5. npcpy/data/text.py +28 -28
  6. npcpy/data/video.py +8 -32
  7. npcpy/data/web.py +51 -23
  8. npcpy/ft/diff.py +110 -0
  9. npcpy/ft/ge.py +115 -0
  10. npcpy/ft/memory_trainer.py +171 -0
  11. npcpy/ft/model_ensembler.py +357 -0
  12. npcpy/ft/rl.py +360 -0
  13. npcpy/ft/sft.py +248 -0
  14. npcpy/ft/usft.py +128 -0
  15. npcpy/gen/audio_gen.py +24 -0
  16. npcpy/gen/embeddings.py +13 -13
  17. npcpy/gen/image_gen.py +262 -117
  18. npcpy/gen/response.py +615 -415
  19. npcpy/gen/video_gen.py +53 -7
  20. npcpy/llm_funcs.py +1869 -437
  21. npcpy/main.py +1 -1
  22. npcpy/memory/command_history.py +844 -510
  23. npcpy/memory/kg_vis.py +833 -0
  24. npcpy/memory/knowledge_graph.py +892 -1845
  25. npcpy/memory/memory_processor.py +81 -0
  26. npcpy/memory/search.py +188 -90
  27. npcpy/mix/debate.py +192 -3
  28. npcpy/npc_compiler.py +1672 -801
  29. npcpy/npc_sysenv.py +593 -1266
  30. npcpy/serve.py +3120 -0
  31. npcpy/sql/ai_function_tools.py +257 -0
  32. npcpy/sql/database_ai_adapters.py +186 -0
  33. npcpy/sql/database_ai_functions.py +163 -0
  34. npcpy/sql/model_runner.py +19 -19
  35. npcpy/sql/npcsql.py +706 -507
  36. npcpy/sql/sql_model_compiler.py +156 -0
  37. npcpy/tools.py +183 -0
  38. npcpy/work/plan.py +13 -279
  39. npcpy/work/trigger.py +3 -3
  40. npcpy-1.2.32.dist-info/METADATA +803 -0
  41. npcpy-1.2.32.dist-info/RECORD +54 -0
  42. npcpy/data/dataframes.py +0 -171
  43. npcpy/memory/deep_research.py +0 -125
  44. npcpy/memory/sleep.py +0 -557
  45. npcpy/modes/_state.py +0 -78
  46. npcpy/modes/alicanto.py +0 -1075
  47. npcpy/modes/guac.py +0 -785
  48. npcpy/modes/mcp_npcsh.py +0 -822
  49. npcpy/modes/npc.py +0 -213
  50. npcpy/modes/npcsh.py +0 -1158
  51. npcpy/modes/plonk.py +0 -409
  52. npcpy/modes/pti.py +0 -234
  53. npcpy/modes/serve.py +0 -1637
  54. npcpy/modes/spool.py +0 -312
  55. npcpy/modes/wander.py +0 -549
  56. npcpy/modes/yap.py +0 -572
  57. npcpy/npc_team/alicanto.npc +0 -2
  58. npcpy/npc_team/alicanto.png +0 -0
  59. npcpy/npc_team/assembly_lines/test_pipeline.py +0 -181
  60. npcpy/npc_team/corca.npc +0 -13
  61. npcpy/npc_team/foreman.npc +0 -7
  62. npcpy/npc_team/frederic.npc +0 -6
  63. npcpy/npc_team/frederic4.png +0 -0
  64. npcpy/npc_team/guac.png +0 -0
  65. npcpy/npc_team/jinxs/automator.jinx +0 -18
  66. npcpy/npc_team/jinxs/bash_executer.jinx +0 -31
  67. npcpy/npc_team/jinxs/calculator.jinx +0 -11
  68. npcpy/npc_team/jinxs/edit_file.jinx +0 -96
  69. npcpy/npc_team/jinxs/file_chat.jinx +0 -14
  70. npcpy/npc_team/jinxs/gui_controller.jinx +0 -28
  71. npcpy/npc_team/jinxs/image_generation.jinx +0 -29
  72. npcpy/npc_team/jinxs/internet_search.jinx +0 -30
  73. npcpy/npc_team/jinxs/local_search.jinx +0 -152
  74. npcpy/npc_team/jinxs/npcsh_executor.jinx +0 -31
  75. npcpy/npc_team/jinxs/python_executor.jinx +0 -8
  76. npcpy/npc_team/jinxs/screen_cap.jinx +0 -25
  77. npcpy/npc_team/jinxs/sql_executor.jinx +0 -33
  78. npcpy/npc_team/kadiefa.npc +0 -3
  79. npcpy/npc_team/kadiefa.png +0 -0
  80. npcpy/npc_team/npcsh.ctx +0 -9
  81. npcpy/npc_team/npcsh_sibiji.png +0 -0
  82. npcpy/npc_team/plonk.npc +0 -2
  83. npcpy/npc_team/plonk.png +0 -0
  84. npcpy/npc_team/plonkjr.npc +0 -2
  85. npcpy/npc_team/plonkjr.png +0 -0
  86. npcpy/npc_team/sibiji.npc +0 -5
  87. npcpy/npc_team/sibiji.png +0 -0
  88. npcpy/npc_team/spool.png +0 -0
  89. npcpy/npc_team/templates/analytics/celona.npc +0 -0
  90. npcpy/npc_team/templates/hr_support/raone.npc +0 -0
  91. npcpy/npc_team/templates/humanities/eriane.npc +0 -4
  92. npcpy/npc_team/templates/it_support/lineru.npc +0 -0
  93. npcpy/npc_team/templates/marketing/slean.npc +0 -4
  94. npcpy/npc_team/templates/philosophy/maurawa.npc +0 -0
  95. npcpy/npc_team/templates/sales/turnic.npc +0 -4
  96. npcpy/npc_team/templates/software/welxor.npc +0 -0
  97. npcpy/npc_team/yap.png +0 -0
  98. npcpy/routes.py +0 -958
  99. npcpy/work/mcp_helpers.py +0 -357
  100. npcpy/work/mcp_server.py +0 -194
  101. npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.npc +0 -2
  102. npcpy-1.0.26.data/data/npcpy/npc_team/alicanto.png +0 -0
  103. npcpy-1.0.26.data/data/npcpy/npc_team/automator.jinx +0 -18
  104. npcpy-1.0.26.data/data/npcpy/npc_team/bash_executer.jinx +0 -31
  105. npcpy-1.0.26.data/data/npcpy/npc_team/calculator.jinx +0 -11
  106. npcpy-1.0.26.data/data/npcpy/npc_team/celona.npc +0 -0
  107. npcpy-1.0.26.data/data/npcpy/npc_team/corca.npc +0 -13
  108. npcpy-1.0.26.data/data/npcpy/npc_team/edit_file.jinx +0 -96
  109. npcpy-1.0.26.data/data/npcpy/npc_team/eriane.npc +0 -4
  110. npcpy-1.0.26.data/data/npcpy/npc_team/file_chat.jinx +0 -14
  111. npcpy-1.0.26.data/data/npcpy/npc_team/foreman.npc +0 -7
  112. npcpy-1.0.26.data/data/npcpy/npc_team/frederic.npc +0 -6
  113. npcpy-1.0.26.data/data/npcpy/npc_team/frederic4.png +0 -0
  114. npcpy-1.0.26.data/data/npcpy/npc_team/guac.png +0 -0
  115. npcpy-1.0.26.data/data/npcpy/npc_team/gui_controller.jinx +0 -28
  116. npcpy-1.0.26.data/data/npcpy/npc_team/image_generation.jinx +0 -29
  117. npcpy-1.0.26.data/data/npcpy/npc_team/internet_search.jinx +0 -30
  118. npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.npc +0 -3
  119. npcpy-1.0.26.data/data/npcpy/npc_team/kadiefa.png +0 -0
  120. npcpy-1.0.26.data/data/npcpy/npc_team/lineru.npc +0 -0
  121. npcpy-1.0.26.data/data/npcpy/npc_team/local_search.jinx +0 -152
  122. npcpy-1.0.26.data/data/npcpy/npc_team/maurawa.npc +0 -0
  123. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh.ctx +0 -9
  124. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_executor.jinx +0 -31
  125. npcpy-1.0.26.data/data/npcpy/npc_team/npcsh_sibiji.png +0 -0
  126. npcpy-1.0.26.data/data/npcpy/npc_team/plonk.npc +0 -2
  127. npcpy-1.0.26.data/data/npcpy/npc_team/plonk.png +0 -0
  128. npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.npc +0 -2
  129. npcpy-1.0.26.data/data/npcpy/npc_team/plonkjr.png +0 -0
  130. npcpy-1.0.26.data/data/npcpy/npc_team/python_executor.jinx +0 -8
  131. npcpy-1.0.26.data/data/npcpy/npc_team/raone.npc +0 -0
  132. npcpy-1.0.26.data/data/npcpy/npc_team/screen_cap.jinx +0 -25
  133. npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.npc +0 -5
  134. npcpy-1.0.26.data/data/npcpy/npc_team/sibiji.png +0 -0
  135. npcpy-1.0.26.data/data/npcpy/npc_team/slean.npc +0 -4
  136. npcpy-1.0.26.data/data/npcpy/npc_team/spool.png +0 -0
  137. npcpy-1.0.26.data/data/npcpy/npc_team/sql_executor.jinx +0 -33
  138. npcpy-1.0.26.data/data/npcpy/npc_team/test_pipeline.py +0 -181
  139. npcpy-1.0.26.data/data/npcpy/npc_team/turnic.npc +0 -4
  140. npcpy-1.0.26.data/data/npcpy/npc_team/welxor.npc +0 -0
  141. npcpy-1.0.26.data/data/npcpy/npc_team/yap.png +0 -0
  142. npcpy-1.0.26.dist-info/METADATA +0 -827
  143. npcpy-1.0.26.dist-info/RECORD +0 -139
  144. npcpy-1.0.26.dist-info/entry_points.txt +0 -11
  145. /npcpy/{modes → ft}/__init__.py +0 -0
  146. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
  147. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
  148. {npcpy-1.0.26.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/data/text.py CHANGED
@@ -37,70 +37,70 @@ def rag_search(
     )
     results = []
 
-    # Compute the embedding of the query
+
     query_embedding = embedding_model.encode(
         query, convert_to_tensor=True, show_progress_bar=False
     )
     if isinstance(text_data, str):
-        # split at the sentence level
+
         lines = text_data.split(".")
        if not lines:
             return results
-        # Compute embeddings for each line
+
         if text_data_embedded is None:
             line_embeddings = embedding_model.encode(lines, convert_to_tensor=True)
         else:
             line_embeddings = text_data_embedded
-        # Compute cosine similarities
+
         cosine_scores = util.cos_sim(query_embedding, line_embeddings)[0].cpu().numpy()
 
-        # Find indices of lines above the similarity threshold
+
         relevant_line_indices = np.where(cosine_scores >= similarity_threshold)[0]
-        # print("relevant_line_indices", cosine_scores)
-        # print(np.mean(cosine_scores))
-        # print(np.max(cosine_scores))
+
+
+
 
         for idx in relevant_line_indices:
             idx = int(idx)
-            # Get context lines (±10 lines)
+
             start_idx = max(0, idx - 10)
-            end_idx = min(len(lines), idx + 11)  # +11 because end index is exclusive
+            end_idx = min(len(lines), idx + 11)
             snippet = ". ".join(lines[start_idx:end_idx])
             results.append(snippet)
 
     elif isinstance(text_data, dict):
         for filename, content in text_data.items():
-            # Split content into lines
+
             lines = content.split("\n")
             if not lines:
                 continue
-            # Compute embeddings for each line
+
             if text_data_embedded is None:
                 line_embeddings = embedding_model.encode(lines, convert_to_tensor=True)
             else:
                 line_embeddings = text_data_embedded[filename]
-            # Compute cosine similarities
+
             cosine_scores = (
                 util.cos_sim(query_embedding, line_embeddings)[0].cpu().numpy()
             )
 
-            # Find indices of lines above the similarity threshold
-            ##print("most similar", np.max(cosine_scores))
-            ##print("most similar doc", lines[np.argmax(cosine_scores)])
+
+
+
             relevant_line_indices = np.where(cosine_scores >= similarity_threshold)[0]
-            # print("relevant_line_indices", cosine_scores)
-            # print(np.mean(cosine_scores))
-            # print(np.max(cosine_scores))
+
+
+
             for idx in relevant_line_indices:
-                idx = int(idx)  # Ensure idx is an integer
-                # Get context lines (±10 lines)
+                idx = int(idx)
+
                 start_idx = max(0, idx - 10)
                 end_idx = min(
                     len(lines), idx + 11
-                )  # +11 because end index is exclusive
+                )
                 snippet = "\n".join(lines[start_idx:end_idx])
                 results.append((filename, snippet))
-            # print("results", results)
+
     return results
 
 
@@ -122,10 +122,10 @@ def load_all_files(
     """
     text_data = {}
     if depth < 1:
-        return text_data  # Reached the specified depth, stop recursion.
+        return text_data
 
     if extensions is None:
-        # Default to common text file extensions
+
         extensions = [
             ".txt",
             ".md",
@@ -139,11 +139,11 @@ def load_all_files(
             ".ts",
             ".tsx",
             ".npc",
-            # Add more extensions if needed
+
         ]
 
     try:
-        # List all entries in the directory
+
         entries = os.listdir(directory)
     except Exception as e:
         print(f"Could not list directory {directory}: {e}")
@@ -159,7 +159,7 @@ def load_all_files(
         except Exception as e:
             print(f"Could not read file {path}: {e}")
         elif os.path.isdir(path):
-            # Recurse into subdirectories, decreasing depth by 1
+
             subdir_data = load_all_files(path, extensions, depth=depth - 1)
             text_data.update(subdir_data)
 
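The rag_search change above only strips comments and debug prints; behavior is unchanged: it embeds the query, embeds each sentence (or each line per file for dict input), keeps indices whose cosine similarity clears similarity_threshold, and returns ±10 lines of context around each hit. A minimal usage sketch, assuming sentence-transformers is installed; the call below is inferred from the parameters visible in the hunk, not the full signature:

    from sentence_transformers import SentenceTransformer
    from npcpy.data.text import rag_search

    # Hypothetical invocation based on the hunk above; argument names
    # and order in the released wheel may differ.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    docs = {"notes.txt": "npcpy compiles NPC profiles.\nJinxs wrap external tools."}
    hits = rag_search(
        "how are NPC profiles compiled?",
        docs,
        embedding_model=model,
        similarity_threshold=0.3,
    )
    for filename, snippet in hits:
        print(filename, snippet[:80])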
npcpy/data/video.py CHANGED
@@ -1,8 +1,10 @@
-# video.py
+
 
 
 def process_video(file_path, table_name):
-    # implement with moon dream
+
+    import cv2
+    import base64
 
     embeddings = []
     texts = []
@@ -16,37 +18,11 @@ def process_video(file_path, table_name):
             if not ret:
                 break
 
-            # Process every nth frame (adjust n as needed for performance)
-            n = 10  # Process every 10th frame
-            if i % n == 0:
-                # Image Embeddings
-                _, buffer = cv2.imencode(".jpg", frame)  # Encode frame as JPG
-                base64_image = base64.b64encode(buffer).decode("utf-8")
-                image_info = {
-                    "filename": f"frame_{i}.jpg",
-                    "file_path": f"data:image/jpeg;base64,{base64_image}",
-                }  # Use data URL for OpenAI
-                image_embedding_response = get_llm_response(
-                    "Describe this image.",
-                    image=image_info,
-                    model="gpt-4",
-                    provider="openai",
-                )  # Replace with your image embedding model
-                if (
-                    isinstance(image_embedding_response, dict)
-                    and "error" in image_embedding_response
-                ):
-                    print(
-                        f"Error generating image embedding: {image_embedding_response['error']}"
-                    )
-                else:
-                    # Assuming your image embedding model returns a textual description
-                    embeddings.append(image_embedding_response)
-                    texts.append(f"Frame {i}: {image_embedding_response}")
-
-        video.release()
+
+            n = 10
+
         return embeddings, texts
 
     except Exception as e:
         print(f"Error processing video: {e}")
-        return [], []  # Return empty lists in case of error
+        return [], []
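Note that 1.2.32 leaves process_video effectively stubbed out: the frame loop now just sets n = 10 and returns the still-empty embeddings/texts lists. For reference, a self-contained sketch of the every-nth-frame sampling the removed block performed, with the model call omitted; the helper name sample_frames is illustrative, not part of the package:

    import base64
    import cv2

    def sample_frames(file_path, n=10):
        # Decode the video and keep every nth frame as a base64 JPEG data URL,
        # mirroring the loop the 1.0.26 version ran before it was removed.
        video = cv2.VideoCapture(file_path)
        frames = []
        i = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break
            if i % n == 0:
                _, buffer = cv2.imencode(".jpg", frame)
                b64 = base64.b64encode(buffer).decode("utf-8")
                frames.append(f"data:image/jpeg;base64,{b64}")
            i += 1
        video.release()
        return frames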
npcpy/data/web.py CHANGED
@@ -1,4 +1,4 @@
-# search.py
+
 
 import requests
 import os
@@ -21,6 +21,25 @@ except:
     pass
 
 
+
+
+
+def search_exa(query: str,
+               api_key: str = None,
+               top_k=5,
+               **kwargs):
+    from exa_py import Exa
+    if api_key is None:
+        api_key = os.environ.get('EXA_API_KEY')
+    exa = Exa(api_key)
+
+    results = exa.search_and_contents(
+        query,
+        text=True
+    )
+    return results.results[0:top_k]
+
+
 def search_perplexity(
     query: str,
     api_key: str = None,
@@ -30,8 +49,11 @@ def search_perplexity(
     top_p: float = 0.9,
 ):
     if api_key is None:
-        api_key = os.environ["PERPLEXITY_API_KEY"]
-        # print("api_key", api_key)
+        api_key = os.environ.get("PERPLEXITY_API_KEY")
+        if api_key is None:
+            raise
+
+
     url = "https://api.perplexity.ai/chat/completions"
     payload = {
         "model": "sonar",
@@ -52,13 +74,17 @@ def search_perplexity(
         "response_format": None,
     }
 
-    # Headers for the request, including the Authorization bearer token
-    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    headers = {"Authorization": f"Bearer {api_key}",
+               "Content-Type": "application/json"}
+
+
+    response = requests.post(url,
+                             json=payload,
+                             headers=headers)
+
+    response = response.json()
 
-    # Make the POST request to the API
-    response = requests.post(url, json=payload, headers=headers)
-    response = json.loads(response.text)
-    #print(response)
     return [response["choices"][0]["message"]["content"], response["citations"]]
 
 
@@ -68,7 +94,7 @@ def search_web(
     provider: str=None,
     api_key=None,
     perplexity_kwargs: Optional[Dict[str, Any]] = None,
-) -> List[Dict[str, str]]:
+) -> List:
     """
     Function Description:
         This function searches the web for information based on a query.
@@ -76,7 +102,7 @@ def search_web(
         query: The search query.
     Keyword Args:
         num_results: The number of search results to retrieve.
-        provider: The search engine provider to use ('google' or 'duckduckgo').
+        provider: The search engine provider to use ('perplexity' or 'duckduckgo').
     Returns:
         A list of dictionaries with 'title', 'link', and 'content' keys.
     """
@@ -88,7 +114,7 @@ def search_web(
 
     if provider == "perplexity":
         search_result = search_perplexity(query, api_key=api_key, **perplexity_kwargs)
-        # print(search_result, type(search_result))
+
         return search_result
 
     if provider == "duckduckgo":
@@ -108,30 +134,32 @@ def search_web(
             print("DuckDuckGo search failed: ", e)
             urls = []
             results = []
+    elif provider =='exa':
+        return search_exa(query, api_key=api_key, )
 
-    elif provider =='google':  # google
+    elif provider =='google':
         urls = list(search(query, num_results=num_results))
-        # google shit doesnt seem to be working anymore, apparently a lbock they made on browsers without js?
-        #print("urls", urls)
-        #print(provider)
+
+
+
     for url in urls:
         try:
-            # Fetch the webpage content
+
             headers = {
                 "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
             }
            response = requests.get(url, headers=headers, timeout=5)
            response.raise_for_status()
 
-            # Parse with BeautifulSoup
+
            soup = BeautifulSoup(response.text, "html.parser")
 
-            # Get title and content
+
            title = soup.title.string if soup.title else url
 
-            # Extract text content and clean it up
+
            content = " ".join([p.get_text() for p in soup.find_all("p")])
-            content = " ".join(content.split())  # Clean up whitespace
+            content = " ".join(content.split())
 
            results.append(
                {
@@ -147,8 +175,8 @@ def search_web(
            print(f"Error fetching {url}: {str(e)}")
            continue
 
-    # except Exception as e:
-    #     print(f"Search error: {str(e)}")
+
+
    content_str = "\n".join(
        [r["content"] + "\n Citation: " + r["link"] + "\n\n\n" for r in results]
    )
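A usage sketch of the new provider switch, assuming the relevant API key is exported (PERPLEXITY_API_KEY, or EXA_API_KEY for provider='exa') and that the function returns the results list it builds (the return statement falls outside the hunk). Note the return shape varies by provider: 'perplexity' yields [content, citations], 'exa' yields exa_py result objects, and the DuckDuckGo/Google paths build the {'title', 'link', 'content'} dictionaries documented in the docstring:

    from npcpy.data.web import search_web

    # provider='duckduckgo' exercises the dict-building path shown above;
    # no API key is needed for this provider.
    results = search_web("npcpy agent framework",
                         num_results=3,
                         provider="duckduckgo")
    for r in results:
        print(r["title"], r["link"])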
npcpy/ft/diff.py ADDED
@@ -0,0 +1,110 @@
+# finetuning diffuser models
+try:
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import DataLoader, Dataset as TorchDataset
+    from transformers import CLIPTextModel, CLIPTokenizer
+except:
+    torch = None
+    nn = None
+    F = None
+    DataLoader = None
+    TorchDataset = None
+    CLIPTextModel = None
+    CLIPTokenizer = None
+import math
+from dataclasses import dataclass, field
+from typing import List, Optional, Callable
+import numpy as np
+from PIL import Image
+import os
+from tqdm import tqdm
+import gc
+
+
+@dataclass
+class DiffusionConfig:
+    image_size: int = 128
+    channels: int = 256
+    time_emb_dim: int = 128
+    timesteps: int = 1000
+    beta_start: float = 1e-4
+    beta_end: float = 0.02
+    num_epochs: int = 100
+    batch_size: int = 4
+    learning_rate: float = 1e-5
+    checkpoint_frequency: int = 1000
+    output_dir: str = "diffusion_model"
+    use_clip: bool = True
+    num_channels: int = 1
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(
+            torch.arange(half_dim, device=device) * -embeddings
+        )
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat(
+            (embeddings.sin(), embeddings.cos()),
+            dim=-1
+        )
+        return embeddings
+
+
+class SimpleUNet(nn.Module):
+
+    def __init__(
+        self,
+        image_size=128,
+        channels=256,
+        time_emb_dim=128,
+        num_channels=1
+    ):
+        super().__init__()
+
+        self.image_size = image_size
+
+        self.time_mlp = nn.Sequential(
+            SinusoidalPositionEmbeddings(time_emb_dim),
+            nn.Linear(time_emb_dim, time_emb_dim * 4),
+            nn.GELU(),
+            nn.Linear(time_emb_dim * 4, channels),
+        )
+
+        self.text_mlp = nn.Sequential(
+            nn.Linear(768, time_emb_dim),
+            nn.GELU(),
+            nn.Linear(time_emb_dim, time_emb_dim),
+            nn.GELU(),
+            nn.Linear(time_emb_dim, channels),
+        )
+
+        self.conv_in = nn.Conv2d(num_channels, channels, 1, padding=0)
+
+        self.down1 = nn.Sequential(
+            nn.Conv2d(channels, channels * 2, 4, 2, 1),
+            nn.GroupNorm(8, channels * 2),
+            nn.GELU(),
+        )
+
+        self.down2 = nn.Sequential(
+            nn.Conv2d(channels * 2, channels * 4, 4, 2, 1),
+            nn.GroupNorm(8, channels * 4),
+            nn.GELU(),
+        )
+
+        self.down3 = nn.Sequential(
+            nn.Conv2d(channels * 4, channels * 8, 4, 2, 1),
+            nn.GroupNorm(8, channels * 8),
+            nn.GELU(),
+        )
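The hunk ends mid-class, so only the downsampling half of SimpleUNet is visible here. SinusoidalPositionEmbeddings is the standard transformer-style timestep encoding: for dimension d, half the channels carry sin(t·ω_i) and half cos(t·ω_i) over geometrically spaced frequencies ω_i. A quick shape check of the module in isolation, assuming the torch import above succeeded:

    import torch
    from npcpy.ft.diff import SinusoidalPositionEmbeddings

    # Embed a batch of 4 diffusion timesteps; each becomes a 128-dim vector
    # (64 sine channels concatenated with 64 cosine channels).
    emb = SinusoidalPositionEmbeddings(dim=128)
    t = torch.tensor([0, 100, 500, 999])
    print(emb(t).shape)  # torch.Size([4, 128])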
npcpy/ft/ge.py ADDED
@@ -0,0 +1,115 @@
+import random
+from dataclasses import dataclass
+from typing import Callable, Optional, List
+
+
+@dataclass
+class GAConfig:
+    population_size: int = 20
+    mutation_rate: float = 0.15
+    crossover_rate: float = 0.7
+    tournament_size: int = 3
+    elitism_count: int = 2
+    generations: int = 50
+
+
+class GeneticEvolver:
+    """
+    Generic GA that takes fitness, mutation, crossover
+    and initialization functions to evolve any population
+    """
+    def __init__(
+        self,
+        fitness_fn: Callable,
+        mutate_fn: Callable,
+        crossover_fn: Callable,
+        initialize_fn: Callable,
+        config: Optional[GAConfig] = None
+    ):
+        self.fitness_fn = fitness_fn
+        self.mutate_fn = mutate_fn
+        self.crossover_fn = crossover_fn
+        self.initialize_fn = initialize_fn
+        self.config = config or GAConfig()
+        self.population = []
+        self.history = []
+
+    def initialize_population(self):
+        self.population = [
+            self.initialize_fn()
+            for _ in range(self.config.population_size)
+        ]
+
+    def evaluate_population(self) -> List[float]:
+        return [
+            self.fitness_fn(individual)
+            for individual in self.population
+        ]
+
+    def tournament_select(self, fitness_scores: List[float]):
+        indices = random.sample(
+            range(len(self.population)),
+            self.config.tournament_size
+        )
+        tournament_fitness = [fitness_scores[i] for i in indices]
+        winner_idx = indices[
+            tournament_fitness.index(max(tournament_fitness))
+        ]
+        return self.population[winner_idx]
+
+    def evolve_generation(self):
+        fitness_scores = self.evaluate_population()
+
+        sorted_pop = sorted(
+            zip(self.population, fitness_scores),
+            key=lambda x: x[1],
+            reverse=True
+        )
+
+        new_population = [
+            ind for ind, _ in sorted_pop[:self.config.elitism_count]
+        ]
+
+        while len(new_population) < self.config.population_size:
+            parent1 = self.tournament_select(fitness_scores)
+            parent2 = self.tournament_select(fitness_scores)
+
+            if random.random() < self.config.crossover_rate:
+                child = self.crossover_fn(parent1, parent2)
+            else:
+                child = parent1
+
+            if random.random() < self.config.mutation_rate:
+                child = self.mutate_fn(child)
+
+            new_population.append(child)
+
+        self.population = new_population[:self.config.population_size]
+
+        best_fitness = max(fitness_scores)
+        avg_fitness = sum(fitness_scores) / len(fitness_scores)
+
+        return {
+            'best_fitness': best_fitness,
+            'avg_fitness': avg_fitness,
+            'best_individual': sorted_pop[0][0]
+        }
+
+    def run(self, generations: Optional[int] = None):
+        if not self.population:
+            self.initialize_population()
+
+        gens = generations or self.config.generations
+
+        for gen in range(gens):
+            gen_stats = self.evolve_generation()
+            self.history.append(gen_stats)
+
+            if gen % 10 == 0:
+                print(
+                    f"Gen {gen}: "
+                    f"Best={gen_stats['best_fitness']:.3f}, "
+                    f"Avg={gen_stats['avg_fitness']:.3f}"
+                )
+
+        return self.history[-1]['best_individual']
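GeneticEvolver is representation-agnostic: the four callables define the search space, and the class only handles selection, elitism, and the generation loop. A usage sketch on the classic OneMax problem (maximize the count of 1-bits in a fixed-length bitstring); the problem and all four callables below are illustrative, not part of the package:

    import random
    from npcpy.ft.ge import GAConfig, GeneticEvolver

    # OneMax: individuals are 20-bit lists; fitness is the number of 1s.
    def init_fn():
        return [random.randint(0, 1) for _ in range(20)]

    def fitness_fn(bits):
        return sum(bits)

    def mutate_fn(bits):
        # Flip one random bit, returning a new list.
        i = random.randrange(len(bits))
        return bits[:i] + [1 - bits[i]] + bits[i + 1:]

    def crossover_fn(a, b):
        # Single-point crossover.
        cut = random.randrange(1, len(a))
        return a[:cut] + b[cut:]

    evolver = GeneticEvolver(
        fitness_fn, mutate_fn, crossover_fn, init_fn,
        config=GAConfig(population_size=30, generations=40),
    )
    best = evolver.run()
    print(sum(best))  # approaches 20 as the population converges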