jax-hpc-profiler 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -1,28 +1,28 @@
1
1
  import os
2
2
  import time
3
3
  from functools import partial
4
- from typing import Any, Callable, List, Optional, Tuple
4
+ from typing import Any, Callable, Optional, Tuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
9
  from jax import make_jaxpr
10
- from jax.experimental import mesh_utils
11
10
  from jax.experimental.shard_map import shard_map
12
- from jax.sharding import Mesh, NamedSharding
11
+ from jax.sharding import NamedSharding
13
12
  from jax.sharding import PartitionSpec as P
14
13
  from jaxtyping import Array
15
14
  from tabulate import tabulate
16
15
 
17
16
 
18
17
  class Timer:
19
-
20
- def __init__(self,
21
- save_jaxpr=False,
22
- compile_info=True,
23
- jax_fn=True,
24
- devices=None,
25
- static_argnums=()):
18
+ def __init__(
19
+ self,
20
+ save_jaxpr=False,
21
+ compile_info=True,
22
+ jax_fn=True,
23
+ devices=None,
24
+ static_argnums=(),
25
+ ):
26
26
  self.jit_time = 0.0
27
27
  self.times = []
28
28
  self.profiling_data = {}
@@ -34,16 +34,14 @@ class Timer:
34
34
  self.static_argnums = static_argnums
35
35
 
36
36
  def _normalize_memory_units(self, memory_analysis) -> str:
37
-
38
37
  if not (self.jax_fn and self.compile_info):
39
38
  return memory_analysis
40
39
 
41
- sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
40
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
42
41
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
43
- factor = 0 if memory_analysis == 0 else int(
44
- np.log10(memory_analysis) // 3)
42
+ factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
45
43
 
46
- return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
44
+ return f'{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}'
47
45
 
48
46
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
49
47
  if memory_analysis is None:
@@ -64,37 +62,33 @@ class Timer:
64
62
  if isinstance(x, Array):
65
63
  x.block_until_ready()
66
64
 
67
- jax.tree_map(_block, out)
65
+ jax.tree.map(_block, out)
68
66
  end = time.perf_counter()
69
67
  self.jit_time = (end - start) * 1e3
70
68
 
71
- self.compiled_code["JAXPR"] = "N/A"
72
- self.compiled_code["LOWERED"] = "N/A"
73
- self.compiled_code["COMPILED"] = "N/A"
74
- self.profiling_data["generated_code"] = "N/A"
75
- self.profiling_data["argument_size"] = "N/A"
76
- self.profiling_data["output_size"] = "N/A"
77
- self.profiling_data["temp_size"] = "N/A"
69
+ self.compiled_code['JAXPR'] = 'N/A'
70
+ self.compiled_code['LOWERED'] = 'N/A'
71
+ self.compiled_code['COMPILED'] = 'N/A'
72
+ self.profiling_data['generated_code'] = 'N/A'
73
+ self.profiling_data['argument_size'] = 'N/A'
74
+ self.profiling_data['output_size'] = 'N/A'
75
+ self.profiling_data['temp_size'] = 'N/A'
78
76
 
79
77
  if self.save_jaxpr:
80
- jaxpr = make_jaxpr(fun,
81
- static_argnums=self.static_argnums)(*args,
82
- **kwargs)
83
- self.compiled_code["JAXPR"] = jaxpr.pretty_print()
78
+ jaxpr = make_jaxpr(fun, static_argnums=self.static_argnums)(*args, **kwargs)
79
+ self.compiled_code['JAXPR'] = jaxpr.pretty_print()
84
80
 
85
81
  if self.jax_fn and self.compile_info:
86
- lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
87
- *args, **kwargs)
82
+ lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(*args, **kwargs)
88
83
  compiled = lowered.compile()
89
- memory_analysis = self._read_memory_analysis(
90
- compiled.memory_analysis())
84
+ memory_analysis = self._read_memory_analysis(compiled.memory_analysis())
91
85
 
92
- self.compiled_code["LOWERED"] = lowered.as_text()
93
- self.compiled_code["COMPILED"] = compiled.as_text()
94
- self.profiling_data["generated_code"] = memory_analysis[0]
95
- self.profiling_data["argument_size"] = memory_analysis[1]
96
- self.profiling_data["output_size"] = memory_analysis[2]
97
- self.profiling_data["temp_size"] = memory_analysis[3]
86
+ self.compiled_code['LOWERED'] = lowered.as_text()
87
+ self.compiled_code['COMPILED'] = compiled.as_text()
88
+ self.profiling_data['generated_code'] = memory_analysis[0]
89
+ self.profiling_data['argument_size'] = memory_analysis[1]
90
+ self.profiling_data['output_size'] = memory_analysis[2]
91
+ self.profiling_data['temp_size'] = memory_analysis[3]
98
92
 
99
93
  return out
100
94
 
@@ -107,7 +101,7 @@ class Timer:
107
101
  if isinstance(x, Array):
108
102
  x.block_until_ready()
109
103
 
110
- jax.tree_map(_block, out)
104
+ jax.tree.map(_block, out)
111
105
  end = time.perf_counter()
112
106
  self.times.append((end - start) * 1e3)
113
107
  return out
@@ -119,9 +113,8 @@ class Timer:
119
113
  if self.devices is None:
120
114
  self.devices = jax.devices()
121
115
 
122
- mesh = jax.make_mesh((len(self.devices), ), ("x", ),
123
- devices=self.devices)
124
- sharding = NamedSharding(mesh, P("x"))
116
+ mesh = jax.make_mesh((len(self.devices),), ('x',), devices=self.devices)
117
+ sharding = NamedSharding(mesh, P('x'))
125
118
 
126
119
  times_array = jnp.array(self.times)
127
120
  global_shape = (jax.device_count(), times_array.shape[0])
@@ -131,13 +124,9 @@ class Timer:
131
124
  data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
132
125
  )
133
126
 
134
- @partial(shard_map,
135
- mesh=mesh,
136
- in_specs=P("x"),
137
- out_specs=P(),
138
- check_rep=False)
139
- def get_mean_times(times):
140
- return jax.lax.pmean(times, axis_name="x")
127
+ @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P(), check_rep=False)
128
+ def get_mean_times(times) -> Array:
129
+ return jax.lax.pmean(times, axis_name='x')
141
130
 
142
131
  times_array = get_mean_times(global_times)
143
132
  times_array.block_until_ready()
@@ -150,17 +139,17 @@ class Timer:
150
139
  x: int,
151
140
  y: int | None = None,
152
141
  z: int | None = None,
153
- precision: str = "float32",
142
+ precision: str = 'float32',
154
143
  px: int = 1,
155
144
  py: int = 1,
156
- backend: str = "NCCL",
145
+ backend: str = 'NCCL',
157
146
  nodes: int = 1,
158
147
  md_filename: str | None = None,
159
- npz_data: Optional[dict] = None,
160
- extra_info: dict = {},
161
- ):
148
+ npz_data: Optional[dict[str, Any]] = None,
149
+ extra_info: dict[str, Any] = {},
150
+ ) -> None:
162
151
  if self.jit_time == 0.0 and len(self.times) == 0:
163
- print(f"No profiling data to report for {function}")
152
+ print(f'No profiling data to report for {function}')
164
153
  return
165
154
 
166
155
  if md_filename is None:
@@ -168,22 +157,18 @@ class Timer:
168
157
  os.path.dirname(csv_filename),
169
158
  os.path.splitext(os.path.basename(csv_filename))[0],
170
159
  )
171
- report_folder = filename if dirname == "" else f"{dirname}/{filename}"
160
+ report_folder = filename if dirname == '' else f'{dirname}/{filename}'
172
161
  os.makedirs(report_folder, exist_ok=True)
173
- md_filename = (
174
- f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
175
- )
162
+ md_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md'
176
163
 
177
164
  if npz_data is not None:
178
165
  dirname, filename = (
179
166
  os.path.dirname(csv_filename),
180
167
  os.path.splitext(os.path.basename(csv_filename))[0],
181
168
  )
182
- report_folder = filename if dirname == "" else f"{dirname}/{filename}"
169
+ report_folder = filename if dirname == '' else f'{dirname}/{filename}'
183
170
  os.makedirs(report_folder, exist_ok=True)
184
- npz_filename = (
185
- f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
186
- )
171
+ npz_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz'
187
172
  np.savez(npz_filename, **npz_data)
188
173
 
189
174
  y = x if y is None else y
@@ -191,96 +176,98 @@ class Timer:
191
176
 
192
177
  times_array = self._get_mean_times()
193
178
  if jax.process_index() == 0:
194
-
195
179
  min_time = np.min(times_array)
196
180
  max_time = np.max(times_array)
197
181
  mean_time = np.mean(times_array)
198
182
  std_time = np.std(times_array)
199
183
  last_time = times_array[-1]
200
- generated_code = self.profiling_data["generated_code"]
201
- argument_size = self.profiling_data["argument_size"]
202
- output_size = self.profiling_data["output_size"]
203
- temp_size = self.profiling_data["temp_size"]
184
+ generated_code = self.profiling_data['generated_code']
185
+ argument_size = self.profiling_data['argument_size']
186
+ output_size = self.profiling_data['output_size']
187
+ temp_size = self.profiling_data['temp_size']
204
188
 
205
189
  csv_line = (
206
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
207
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
208
- f"{generated_code},{argument_size},{output_size},{temp_size}\n"
190
+ f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
191
+ f'{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},'
192
+ f'{generated_code},{argument_size},{output_size},{temp_size}\n'
209
193
  )
210
194
 
211
- with open(csv_filename, "a") as f:
195
+ with open(csv_filename, 'a') as f:
212
196
  f.write(csv_line)
213
197
 
214
198
  param_dict = {
215
- "Function": function,
216
- "Precision": precision,
217
- "X": x,
218
- "Y": y,
219
- "Z": z,
220
- "PX": px,
221
- "PY": py,
222
- "Backend": backend,
223
- "Nodes": nodes,
199
+ 'Function': function,
200
+ 'Precision': precision,
201
+ 'X': x,
202
+ 'Y': y,
203
+ 'Z': z,
204
+ 'PX': px,
205
+ 'PY': py,
206
+ 'Backend': backend,
207
+ 'Nodes': nodes,
224
208
  }
225
209
  param_dict.update(extra_info)
226
210
  profiling_result = {
227
- "JIT Time": self.jit_time,
228
- "Min Time": min_time,
229
- "Max Time": max_time,
230
- "Mean Time": mean_time,
231
- "Std Time": std_time,
232
- "Last Time": last_time,
233
- "Generated Code": self._normalize_memory_units(generated_code),
234
- "Argument Size": self._normalize_memory_units(argument_size),
235
- "Output Size": self._normalize_memory_units(output_size),
236
- "Temporary Size": self._normalize_memory_units(temp_size),
211
+ 'JIT Time': self.jit_time,
212
+ 'Min Time': min_time,
213
+ 'Max Time': max_time,
214
+ 'Mean Time': mean_time,
215
+ 'Std Time': std_time,
216
+ 'Last Time': last_time,
217
+ 'Generated Code': self._normalize_memory_units(generated_code),
218
+ 'Argument Size': self._normalize_memory_units(argument_size),
219
+ 'Output Size': self._normalize_memory_units(output_size),
220
+ 'Temporary Size': self._normalize_memory_units(temp_size),
237
221
  }
238
222
  iteration_runs = {}
239
223
  for i in range(len(times_array)):
240
- iteration_runs[f"Run {i}"] = times_array[i]
224
+ iteration_runs[f'Run {i}'] = times_array[i]
241
225
 
242
- with open(md_filename, "w") as f:
243
- f.write(f"# Reporting for {function}\n")
244
- f.write(f"## Parameters\n")
226
+ with open(md_filename, 'w') as f:
227
+ f.write(f'# Reporting for {function}\n')
228
+ f.write('## Parameters\n')
245
229
  f.write(
246
230
  tabulate(
247
231
  param_dict.items(),
248
- headers=["Parameter", "Value"],
249
- tablefmt="github",
250
- ))
251
- f.write("\n---\n")
252
- f.write(f"## Profiling Data\n")
232
+ headers=['Parameter', 'Value'],
233
+ tablefmt='github',
234
+ )
235
+ )
236
+ f.write('\n---\n')
237
+ f.write('## Profiling Data\n')
253
238
  f.write(
254
239
  tabulate(
255
240
  profiling_result.items(),
256
- headers=["Parameter", "Value"],
257
- tablefmt="github",
258
- ))
259
- f.write("\n---\n")
260
- f.write(f"## Iteration Runs\n")
241
+ headers=['Parameter', 'Value'],
242
+ tablefmt='github',
243
+ )
244
+ )
245
+ f.write('\n---\n')
246
+ f.write('## Iteration Runs\n')
261
247
  f.write(
262
248
  tabulate(
263
249
  iteration_runs.items(),
264
- headers=["Iteration", "Time"],
265
- tablefmt="github",
266
- ))
250
+ headers=['Iteration', 'Time'],
251
+ tablefmt='github',
252
+ )
253
+ )
267
254
  if self.jax_fn and self.compile_info:
268
- f.write("\n---\n")
269
- f.write(f"## Compiled Code\n")
270
- f.write(f"```hlo\n")
271
- f.write(self.compiled_code["COMPILED"])
272
- f.write(f"\n```\n")
273
- f.write("\n---\n")
274
- f.write(f"## Lowered Code\n")
275
- f.write(f"```hlo\n")
276
- f.write(self.compiled_code["LOWERED"])
277
- f.write(f"\n```\n")
278
- f.write("\n---\n")
255
+ f.write('\n---\n')
256
+ f.write('## Compiled Code\n')
257
+ f.write('```hlo\n')
258
+ f.write(self.compiled_code['COMPILED'])
259
+ f.write('\n```\n')
260
+ f.write('\n---\n')
261
+ f.write('## Lowered Code\n')
262
+ f.write('```hlo\n')
263
+ f.write(self.compiled_code['LOWERED'])
264
+ f.write('\n```\n')
265
+ f.write('\n---\n')
279
266
  if self.save_jaxpr:
280
- f.write(f"## JAXPR\n")
281
- f.write(f"```haskel\n")
282
- f.write(self.compiled_code["JAXPR"])
283
- f.write(f"\n```\n")
267
+ f.write('## JAXPR\n')
268
+ f.write('```haskel\n')
269
+ f.write(self.compiled_code['JAXPR'])
270
+ f.write('\n```\n')
284
271
 
285
272
  # Reset the timer
286
273
  self.jit_time = 0.0