jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -1,32 +1,36 @@
1
1
  import os
2
2
  import time
3
3
  from functools import partial
4
- from typing import Any, Callable, List, Optional, Tuple
4
+ from typing import Any, Callable, Optional, Tuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
- from jax import make_jaxpr
10
- from jax.experimental import mesh_utils
11
- from jax.experimental.shard_map import shard_map
12
- from jax.sharding import Mesh, NamedSharding
9
+ from jax import make_jaxpr, shard_map
10
+ from jax.sharding import NamedSharding
13
11
  from jax.sharding import PartitionSpec as P
14
12
  from jaxtyping import Array
15
13
  from tabulate import tabulate
16
14
 
17
15
 
18
16
  class Timer:
19
-
20
- def __init__(self,
21
- save_jaxpr=False,
22
- compile_info=True,
23
- jax_fn=True,
24
- devices=None,
25
- static_argnums=()):
17
+ def __init__(
18
+ self,
19
+ save_jaxpr=False,
20
+ compile_info=True,
21
+ jax_fn=True,
22
+ devices=None,
23
+ static_argnums=(),
24
+ ):
26
25
  self.jit_time = 0.0
27
26
  self.times = []
28
- self.profiling_data = {}
29
- self.compiled_code = {}
27
+ self.profiling_data = {
28
+ 'generated_code': 'N/A',
29
+ 'argument_size': 'N/A',
30
+ 'output_size': 'N/A',
31
+ 'temp_size': 'N/A',
32
+ }
33
+ self.compiled_code = {'JAXPR': 'N/A', 'LOWERED': 'N/A', 'COMPILED': 'N/A'}
30
34
  self.save_jaxpr = save_jaxpr
31
35
  self.compile_info = compile_info
32
36
  self.jax_fn = jax_fn
@@ -34,16 +38,14 @@ class Timer:
34
38
  self.static_argnums = static_argnums
35
39
 
36
40
  def _normalize_memory_units(self, memory_analysis) -> str:
37
-
38
41
  if not (self.jax_fn and self.compile_info):
39
42
  return memory_analysis
40
43
 
41
- sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
44
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
42
45
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
43
- factor = 0 if memory_analysis == 0 else int(
44
- np.log10(memory_analysis) // 3)
46
+ factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
45
47
 
46
- return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
48
+ return f'{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}'
47
49
 
48
50
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
49
51
  if memory_analysis is None:
@@ -64,37 +66,33 @@ class Timer:
64
66
  if isinstance(x, Array):
65
67
  x.block_until_ready()
66
68
 
67
- jax.tree_map(_block, out)
69
+ jax.tree.map(_block, out)
68
70
  end = time.perf_counter()
69
71
  self.jit_time = (end - start) * 1e3
70
72
 
71
- self.compiled_code["JAXPR"] = "N/A"
72
- self.compiled_code["LOWERED"] = "N/A"
73
- self.compiled_code["COMPILED"] = "N/A"
74
- self.profiling_data["generated_code"] = "N/A"
75
- self.profiling_data["argument_size"] = "N/A"
76
- self.profiling_data["output_size"] = "N/A"
77
- self.profiling_data["temp_size"] = "N/A"
73
+ self.compiled_code['JAXPR'] = 'N/A'
74
+ self.compiled_code['LOWERED'] = 'N/A'
75
+ self.compiled_code['COMPILED'] = 'N/A'
76
+ self.profiling_data['generated_code'] = 'N/A'
77
+ self.profiling_data['argument_size'] = 'N/A'
78
+ self.profiling_data['output_size'] = 'N/A'
79
+ self.profiling_data['temp_size'] = 'N/A'
78
80
 
79
81
  if self.save_jaxpr:
80
- jaxpr = make_jaxpr(fun,
81
- static_argnums=self.static_argnums)(*args,
82
- **kwargs)
83
- self.compiled_code["JAXPR"] = jaxpr.pretty_print()
82
+ jaxpr = make_jaxpr(fun, static_argnums=self.static_argnums)(*args, **kwargs)
83
+ self.compiled_code['JAXPR'] = jaxpr.pretty_print()
84
84
 
85
85
  if self.jax_fn and self.compile_info:
86
- lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
87
- *args, **kwargs)
86
+ lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(*args, **kwargs)
88
87
  compiled = lowered.compile()
89
- memory_analysis = self._read_memory_analysis(
90
- compiled.memory_analysis())
88
+ memory_analysis = self._read_memory_analysis(compiled.memory_analysis())
91
89
 
92
- self.compiled_code["LOWERED"] = lowered.as_text()
93
- self.compiled_code["COMPILED"] = compiled.as_text()
94
- self.profiling_data["generated_code"] = memory_analysis[0]
95
- self.profiling_data["argument_size"] = memory_analysis[1]
96
- self.profiling_data["output_size"] = memory_analysis[2]
97
- self.profiling_data["temp_size"] = memory_analysis[3]
90
+ self.compiled_code['LOWERED'] = lowered.as_text()
91
+ self.compiled_code['COMPILED'] = compiled.as_text()
92
+ self.profiling_data['generated_code'] = memory_analysis[0]
93
+ self.profiling_data['argument_size'] = memory_analysis[1]
94
+ self.profiling_data['output_size'] = memory_analysis[2]
95
+ self.profiling_data['temp_size'] = memory_analysis[3]
98
96
 
99
97
  return out
100
98
 
@@ -107,7 +105,7 @@ class Timer:
107
105
  if isinstance(x, Array):
108
106
  x.block_until_ready()
109
107
 
110
- jax.tree_map(_block, out)
108
+ jax.tree.map(_block, out)
111
109
  end = time.perf_counter()
112
110
  self.times.append((end - start) * 1e3)
113
111
  return out
@@ -119,9 +117,8 @@ class Timer:
119
117
  if self.devices is None:
120
118
  self.devices = jax.devices()
121
119
 
122
- mesh = jax.make_mesh((len(self.devices), ), ("x", ),
123
- devices=self.devices)
124
- sharding = NamedSharding(mesh, P("x"))
120
+ mesh = jax.make_mesh((len(self.devices),), ('x',), devices=self.devices)
121
+ sharding = NamedSharding(mesh, P('x'))
125
122
 
126
123
  times_array = jnp.array(self.times)
127
124
  global_shape = (jax.device_count(), times_array.shape[0])
@@ -131,13 +128,9 @@ class Timer:
131
128
  data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
132
129
  )
133
130
 
# `shard_map` here is the top-level `jax.shard_map`, whose replication-check
# switch is `check_vma`; it does not accept the old experimental `check_rep`.
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P(), check_vma=False)
def get_mean_times(times) -> Array:
    """Average the per-device timing rows across the 'x' mesh axis.

    Each device holds one row of the globally sharded timing array
    (``in_specs=P('x')``); ``pmean`` over ``'x'`` yields the same averaged
    row on every device, which is returned unsharded (``out_specs=P()``).
    """
    return jax.lax.pmean(times, axis_name='x')
141
134
 
142
135
  times_array = get_mean_times(global_times)
143
136
  times_array.block_until_ready()
@@ -150,17 +143,17 @@ class Timer:
150
143
  x: int,
151
144
  y: int | None = None,
152
145
  z: int | None = None,
153
- precision: str = "float32",
146
+ precision: str = 'float32',
154
147
  px: int = 1,
155
148
  py: int = 1,
156
- backend: str = "NCCL",
149
+ backend: str = 'NCCL',
157
150
  nodes: int = 1,
158
151
  md_filename: str | None = None,
159
- npz_data: Optional[dict] = None,
160
- extra_info: dict = {},
161
- ):
152
+ npz_data: Optional[dict[str, Any]] = None,
153
+ extra_info: dict[str, Any] = {},
154
+ ) -> None:
162
155
  if self.jit_time == 0.0 and len(self.times) == 0:
163
- print(f"No profiling data to report for {function}")
156
+ print(f'No profiling data to report for {function}')
164
157
  return
165
158
 
166
159
  if md_filename is None:
@@ -168,22 +161,18 @@ class Timer:
168
161
  os.path.dirname(csv_filename),
169
162
  os.path.splitext(os.path.basename(csv_filename))[0],
170
163
  )
171
- report_folder = filename if dirname == "" else f"{dirname}/{filename}"
164
+ report_folder = filename if dirname == '' else f'{dirname}/{filename}'
172
165
  os.makedirs(report_folder, exist_ok=True)
173
- md_filename = (
174
- f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
175
- )
166
+ md_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md'
176
167
 
177
168
  if npz_data is not None:
178
169
  dirname, filename = (
179
170
  os.path.dirname(csv_filename),
180
171
  os.path.splitext(os.path.basename(csv_filename))[0],
181
172
  )
182
- report_folder = filename if dirname == "" else f"{dirname}/{filename}"
173
+ report_folder = filename if dirname == '' else f'{dirname}/{filename}'
183
174
  os.makedirs(report_folder, exist_ok=True)
184
- npz_filename = (
185
- f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
186
- )
175
+ npz_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz'
187
176
  np.savez(npz_filename, **npz_data)
188
177
 
189
178
  y = x if y is None else y
@@ -191,96 +180,98 @@ class Timer:
191
180
 
192
181
  times_array = self._get_mean_times()
193
182
  if jax.process_index() == 0:
194
-
195
183
  min_time = np.min(times_array)
196
184
  max_time = np.max(times_array)
197
185
  mean_time = np.mean(times_array)
198
186
  std_time = np.std(times_array)
199
187
  last_time = times_array[-1]
200
- generated_code = self.profiling_data["generated_code"]
201
- argument_size = self.profiling_data["argument_size"]
202
- output_size = self.profiling_data["output_size"]
203
- temp_size = self.profiling_data["temp_size"]
188
+ generated_code = self.profiling_data.get('generated_code', 'N/A')
189
+ argument_size = self.profiling_data.get('argument_size', 'N/A')
190
+ output_size = self.profiling_data.get('output_size', 'N/A')
191
+ temp_size = self.profiling_data.get('temp_size', 'N/A')
204
192
 
205
193
  csv_line = (
206
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
207
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
208
- f"{generated_code},{argument_size},{output_size},{temp_size}\n"
194
+ f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
195
+ f'{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},'
196
+ f'{generated_code},{argument_size},{output_size},{temp_size}\n'
209
197
  )
210
198
 
211
- with open(csv_filename, "a") as f:
199
+ with open(csv_filename, 'a') as f:
212
200
  f.write(csv_line)
213
201
 
214
202
  param_dict = {
215
- "Function": function,
216
- "Precision": precision,
217
- "X": x,
218
- "Y": y,
219
- "Z": z,
220
- "PX": px,
221
- "PY": py,
222
- "Backend": backend,
223
- "Nodes": nodes,
203
+ 'Function': function,
204
+ 'Precision': precision,
205
+ 'X': x,
206
+ 'Y': y,
207
+ 'Z': z,
208
+ 'PX': px,
209
+ 'PY': py,
210
+ 'Backend': backend,
211
+ 'Nodes': nodes,
224
212
  }
225
213
  param_dict.update(extra_info)
226
214
  profiling_result = {
227
- "JIT Time": self.jit_time,
228
- "Min Time": min_time,
229
- "Max Time": max_time,
230
- "Mean Time": mean_time,
231
- "Std Time": std_time,
232
- "Last Time": last_time,
233
- "Generated Code": self._normalize_memory_units(generated_code),
234
- "Argument Size": self._normalize_memory_units(argument_size),
235
- "Output Size": self._normalize_memory_units(output_size),
236
- "Temporary Size": self._normalize_memory_units(temp_size),
215
+ 'JIT Time': self.jit_time,
216
+ 'Min Time': min_time,
217
+ 'Max Time': max_time,
218
+ 'Mean Time': mean_time,
219
+ 'Std Time': std_time,
220
+ 'Last Time': last_time,
221
+ 'Generated Code': self._normalize_memory_units(generated_code),
222
+ 'Argument Size': self._normalize_memory_units(argument_size),
223
+ 'Output Size': self._normalize_memory_units(output_size),
224
+ 'Temporary Size': self._normalize_memory_units(temp_size),
237
225
  }
238
226
  iteration_runs = {}
239
227
  for i in range(len(times_array)):
240
- iteration_runs[f"Run {i}"] = times_array[i]
228
+ iteration_runs[f'Run {i}'] = times_array[i]
241
229
 
242
- with open(md_filename, "w") as f:
243
- f.write(f"# Reporting for {function}\n")
244
- f.write(f"## Parameters\n")
230
+ with open(md_filename, 'w') as f:
231
+ f.write(f'# Reporting for {function}\n')
232
+ f.write('## Parameters\n')
245
233
  f.write(
246
234
  tabulate(
247
235
  param_dict.items(),
248
- headers=["Parameter", "Value"],
249
- tablefmt="github",
250
- ))
251
- f.write("\n---\n")
252
- f.write(f"## Profiling Data\n")
236
+ headers=['Parameter', 'Value'],
237
+ tablefmt='github',
238
+ )
239
+ )
240
+ f.write('\n---\n')
241
+ f.write('## Profiling Data\n')
253
242
  f.write(
254
243
  tabulate(
255
244
  profiling_result.items(),
256
- headers=["Parameter", "Value"],
257
- tablefmt="github",
258
- ))
259
- f.write("\n---\n")
260
- f.write(f"## Iteration Runs\n")
245
+ headers=['Parameter', 'Value'],
246
+ tablefmt='github',
247
+ )
248
+ )
249
+ f.write('\n---\n')
250
+ f.write('## Iteration Runs\n')
261
251
  f.write(
262
252
  tabulate(
263
253
  iteration_runs.items(),
264
- headers=["Iteration", "Time"],
265
- tablefmt="github",
266
- ))
254
+ headers=['Iteration', 'Time'],
255
+ tablefmt='github',
256
+ )
257
+ )
267
258
  if self.jax_fn and self.compile_info:
268
- f.write("\n---\n")
269
- f.write(f"## Compiled Code\n")
270
- f.write(f"```hlo\n")
271
- f.write(self.compiled_code["COMPILED"])
272
- f.write(f"\n```\n")
273
- f.write("\n---\n")
274
- f.write(f"## Lowered Code\n")
275
- f.write(f"```hlo\n")
276
- f.write(self.compiled_code["LOWERED"])
277
- f.write(f"\n```\n")
278
- f.write("\n---\n")
259
+ f.write('\n---\n')
260
+ f.write('## Compiled Code\n')
261
+ f.write('```hlo\n')
262
+ f.write(self.compiled_code['COMPILED'])
263
+ f.write('\n```\n')
264
+ f.write('\n---\n')
265
+ f.write('## Lowered Code\n')
266
+ f.write('```hlo\n')
267
+ f.write(self.compiled_code['LOWERED'])
268
+ f.write('\n```\n')
269
+ f.write('\n---\n')
279
270
  if self.save_jaxpr:
280
- f.write(f"## JAXPR\n")
281
- f.write(f"```haskel\n")
282
- f.write(self.compiled_code["JAXPR"])
283
- f.write(f"\n```\n")
271
+ f.write('## JAXPR\n')
272
+ f.write('```haskel\n')
273
+ f.write(self.compiled_code['JAXPR'])
274
+ f.write('\n```\n')
284
275
 
285
276
  # Reset the timer
286
277
  self.jit_time = 0.0