
Commit b280ac6

update %timeit

1 parent 38889cb commit b280ac6

File tree: 1 file changed (+39, -15 lines)

lectures/jax_intro.md

Lines changed: 39 additions & 15 deletions
@@ -291,23 +291,47 @@ x = jnp.ones(n)
 How long does the function take to execute?
 
 ```{code-cell} ipython3
-%time f(x).block_until_ready()
+%timeit f(x).block_until_ready()
 ```
 
 ```{note}
 Here, in order to measure actual speed, we use the `block_until_ready()` method
 to hold the interpreter until the results of the computation are returned from
 the device. This is necessary because JAX uses asynchronous dispatch, which
 allows the Python interpreter to run ahead of GPU computations.
+```
+
+```{note}
+Here, we use the [`%timeit` magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-timeit)
+to time the execution of the function.
+
+This command runs the code multiple times to get a more accurate measurement.
+
+Alternatively, we could use the `%time` magic command:
+
+```{code-cell} ipython3
+%time f(x).block_until_ready()
+```
+
+Unlike `%timeit`, this command runs the code only once.
+
+The `%timeit` magic command offers several advantages over `%time`:
 
+* It executes the code multiple times, providing a more accurate measurement
+  for short code snippets
+* It reports both the average execution time and the standard deviation
+
+However, `%timeit` can be time-consuming for large computations.
+
+This is why we switch to using `%time` in later lectures.
 ```
 
 The code doesn't run as fast as we might hope, given that it's running on a GPU.
 
 But if we run it a second time it becomes much faster:
 
 ```{code-cell} ipython3
-%time f(x).block_until_ready()
+%timeit f(x).block_until_ready()
 ```
 
 This is because the built-in functions like `jnp.cos` are JIT compiled and the
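For reference, `%timeit` also accepts `-r` (number of timing repeats) and `-n` (loops per repeat) flags, so even an expensive computation can be timed with a bounded number of runs. A usage sketch, assuming the same `f` and `x` as in the surrounding lecture code:

```ipython3
# Run the statement exactly once (1 loop, 1 repeat):
# close to what %time does, but with %timeit's reporting
%timeit -n 1 -r 1 f(x).block_until_ready()
```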
@@ -330,7 +354,7 @@ y = jnp.ones(m)
 ```
 
 ```{code-cell} ipython3
-%time f(y).block_until_ready()
+%timeit f(y).block_until_ready()
 ```
 
 Notice that the execution time increases, because now new versions of
@@ -341,14 +365,14 @@ If we run again, the code is dispatched to the correct compiled version and we
 get faster execution.
 
 ```{code-cell} ipython3
-%time f(y).block_until_ready()
+%timeit f(y).block_until_ready()
 ```
 
 The compiled versions for the previous array size are still available in memory
 too, and the following call is dispatched to the correct compiled code.
 
 ```{code-cell} ipython3
-%time f(x).block_until_ready()
+%timeit f(x).block_until_ready()
 ```
 
 ### Compiling the outer function
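The pattern in the hunks above (timings jump when the array size changes, then drop again) reflects JAX caching one compiled executable per input shape. A minimal sketch of that behavior, using `jax.jit` to make the caching explicit; the body of `f` here is an assumed stand-in, since the lecture's definition lies outside this diff:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.cos(x) ** 2   # assumed stand-in for the lecture's function

x = jnp.ones(10_000)
y = jnp.ones(20_000)

f(x).block_until_ready()   # first call: traced and compiled for shape (10000,)
f(y).block_until_ready()   # new shape: triggers a fresh compilation
f(x).block_until_ready()   # the executable for shape (10000,) is still cached
```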
@@ -368,7 +392,7 @@ f_jit(x)
 And now let's time it.
 
 ```{code-cell} ipython3
-%time f_jit(x).block_until_ready()
+%timeit f_jit(x).block_until_ready()
 ```
 
 Note the speed gain.
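This hunk times a jitted wrapper after a warm-up call, which is the standard pattern: pay the compilation cost once, then measure steady-state speed. A sketch, assuming `f_jit = jax.jit(f)` as in the lecture text surrounding this hunk:

```ipython3
f_jit = jax.jit(f)            # compile the whole function as one XLA program
f_jit(x).block_until_ready()  # warm-up call pays the compilation cost

%timeit f_jit(x).block_until_ready()  # later calls reuse the compiled executable
```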
@@ -523,7 +547,7 @@ z_loops = np.empty((n, n))
 ```
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 for i in range(n):
     for j in range(n):
         z_loops[i, j] = f(x[i], y[j])
@@ -564,14 +588,14 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
 Now we get what we want and the execution time is very fast.
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_mesh = f(x_mesh, y_mesh).block_until_ready()
 ```
 
 Let's run again to eliminate compile time.
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_mesh = f(x_mesh, y_mesh).block_until_ready()
 ```
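One caveat that applies to the `%%timeit` cells introduced in this commit: `%%timeit` executes the cell body repeatedly in a temporary namespace, so an assignment such as `z_mesh = ...` does not persist into later cells. If the array is needed afterwards, it must also be computed once in an ordinary cell:

```ipython3
# Run once outside %%timeit so z_mesh stays available in the session
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```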

@@ -591,14 +615,14 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
 ```
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_mesh = f(x_mesh, y_mesh).block_until_ready()
 ```
 
 Let's run again to get rid of compile time.
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_mesh = f(x_mesh, y_mesh).block_until_ready()
 ```
@@ -637,14 +661,14 @@ f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
 With this construction, we can now call the function $f$ on flat (low memory) arrays.
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_vmap = f_vec(x, y).block_until_ready()
 ```
 
 We run it again to eliminate compile time.
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 z_vmap = f_vec(x, y).block_until_ready()
 ```
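For context, the `jax.vmap` construction named in this hunk's header turns a scalar bivariate function into one that evaluates over a whole grid without building meshgrid arrays. A self-contained sketch, with an assumed function body standing in for the lecture's `f`:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # assumed stand-in for the lecture's bivariate function
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

f_vec_y = jax.vmap(f, in_axes=(None, 0))      # vectorize in y, x held fixed
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))  # then vectorize in x

n = 1_000
x = jnp.linspace(0, 1, n)
y = jnp.linspace(0, 1, n)

z_vmap = f_vec(x, y).block_until_ready()      # shape (n, n) result from flat inputs
```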

@@ -716,14 +740,14 @@ def compute_call_price_jax(β=β,
 Let's run it once to compile it:
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 compute_call_price_jax().block_until_ready()
 ```
 
 And now let's time it:
 
 ```{code-cell} ipython3
-%%time
+%%timeit
 compute_call_price_jax().block_until_ready()
 ```
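`compute_call_price_jax` itself is defined outside this diff. As a rough indication of the kind of computation being timed, here is a minimal JAX Monte Carlo sketch for a European call price; all names and parameter values below are hypothetical, and the lecture's actual model (note its `β` discount parameter) may differ:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters; the lecture's β, S0, K, etc. are not shown in this diff
S0, K, r, sigma, T = 100.0, 100.0, 0.05, 0.2, 1.0
M = 10_000_000   # number of Monte Carlo draws

@jax.jit
def compute_call_price_mc(key=jax.random.PRNGKey(0)):
    # Terminal stock price under geometric Brownian motion
    Z = jax.random.normal(key, (M,))
    S_T = S0 * jnp.exp((r - 0.5 * sigma**2) * T + sigma * jnp.sqrt(T) * Z)
    # Discounted expected payoff of a European call
    return jnp.exp(-r * T) * jnp.mean(jnp.maximum(S_T - K, 0.0))

compute_call_price_mc().block_until_ready()   # warm-up call compiles the function
```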
