@@ -291,23 +291,47 @@ x = jnp.ones(n)
How long does the function take to execute?

```{code-cell} ipython3
- %time f(x).block_until_ready()
+ %timeit f(x).block_until_ready()
```

```{note}
Here, in order to measure actual speed, we use the `block_until_ready()` method
to hold the interpreter until the results of the computation are returned from
the device. This is necessary because JAX uses asynchronous dispatch, which
allows the Python interpreter to run ahead of GPU computations.
+ ```
+
+ ```{note}
+ Here, we use the [`%timeit` magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-timeit)
+ to time the execution of the function.
+
+ This command runs the code multiple times to get a more accurate measurement.
+
+ Alternatively, we could use the `%time` magic command:
+
+ ```{code-cell} ipython3
+ %time f(x).block_until_ready()
+ ```
+
+ Unlike `%timeit`, this command runs the code only once.
+
+ The `%timeit` magic command offers several advantages over `%time`:
+
+ * It executes the code multiple times, providing a more accurate
+   measurement for short code snippets
+ * It reports both the average execution time and the standard deviation
+
+ However, `%timeit` can be time-consuming for large computations.
+
+ This is why we switch to using `%time` in later lectures.
```
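As an aside, the multiple-run measurement that `%timeit` performs can be reproduced outside IPython with the standard-library `timeit` module. A minimal sketch with a toy stand-in workload (the function and sizes here are illustrative, not the lecture's):

```python
import timeit
import math

def f(x):
    # toy workload standing in for the lecture's JAX function
    return [math.cos(v) for v in x]

x = [i * 0.01 for i in range(1000)]

# timeit.repeat mirrors %timeit: several repeats of many loops each;
# the minimum over repeats is the most stable estimate
times = timeit.repeat(lambda: f(x), number=10, repeat=3)
best = min(times) / 10  # seconds per call in the fastest repeat
print(f"best of 3: {best:.2e} s per call")
```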
The code doesn't run as fast as we might hope, given that it's running on a GPU.

But if we run it a second time it becomes much faster:

```{code-cell} ipython3
- %time f(x).block_until_ready()
+ %timeit f(x).block_until_ready()
```

This is because built-in functions like `jnp.cos` are JIT compiled and the
@@ -330,7 +354,7 @@ y = jnp.ones(m)
```

```{code-cell} ipython3
- %time f(y).block_until_ready()
+ %timeit f(y).block_until_ready()
```

Notice that the execution time increases, because now new versions of
@@ -341,14 +365,14 @@ If we run again, the code is dispatched to the correct compiled version and we
get faster execution.

```{code-cell} ipython3
- %time f(y).block_until_ready()
+ %timeit f(y).block_until_ready()
```

The compiled versions for the previous array size are still available in memory
too, and the following call is dispatched to the correct compiled code.

```{code-cell} ipython3
- %time f(x).block_until_ready()
+ %timeit f(x).block_until_ready()
```
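This per-shape caching can be made visible with a Python-side counter that increments only while JAX traces (compiles) the function; a small sketch, with `jnp.cos` standing in for the lecture's workload:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # runs only during tracing, not on cached calls
    return jnp.cos(x)

f(jnp.ones(10))   # first shape: triggers a trace/compile
f(jnp.ones(10))   # same shape: dispatched to cached compiled code
f(jnp.ones(20))   # new shape: triggers another trace/compile
print(trace_count)  # → 2
```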

### Compiling the outer function
@@ -368,7 +392,7 @@ f_jit(x)
And now let's time it.

```{code-cell} ipython3
- %time f_jit(x).block_until_ready()
+ %timeit f_jit(x).block_until_ready()
```

Note the speed gain.
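The pattern above is simply `jax.jit` applied to the whole outer function; a self-contained sketch (the function body here is a hypothetical stand-in, not the lecture's `f`):

```python
import jax
import jax.numpy as jnp

def f(x):
    # hypothetical stand-in for the lecture's function
    return jnp.cos(2 * x**2)

f_jit = jax.jit(f)  # compile the outer function as a single XLA program

x = jnp.ones(1_000)
result = f_jit(x).block_until_ready()  # first call compiles; later calls reuse it
print(result.shape)
```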
@@ -523,7 +547,7 @@ z_loops = np.empty((n, n))
```

```{code-cell} ipython3
- %%time
+ %%timeit
for i in range(n):
    for j in range(n):
        z_loops[i, j] = f(x[i], y[j])
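For reference, the double loop above can be written out self-contained with a small `n` and compared against a vectorized equivalent via broadcasting (the bivariate `f` here is a hypothetical stand-in):

```python
import numpy as np

def f(x, y):
    # hypothetical bivariate function standing in for the lecture's f
    return np.cos(x**2 + y**2)

n = 4
x = np.linspace(0, 1, n)
y = np.linspace(0, 1, n)

# explicit double loop: n**2 scalar calls
z_loops = np.empty((n, n))
for i in range(n):
    for j in range(n):
        z_loops[i, j] = f(x[i], y[j])

# same result in one vectorized call via broadcasting
z_bcast = f(x[:, None], y[None, :])
print(np.allclose(z_loops, z_bcast))  # → True
```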
@@ -564,14 +588,14 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
Now we get what we want and the execution time is very fast.

```{code-cell} ipython3
- %%time
+ %%timeit
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
- %%time
+ %%timeit
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```
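The meshgrid step expands the two flat coordinate vectors into full 2D arrays so that `f` can be applied elementwise in one call; a small self-contained sketch (with a hypothetical bivariate `f`):

```python
import jax.numpy as jnp

def f(x, y):
    # hypothetical bivariate function
    return jnp.cos(x**2 + y**2)

x = jnp.linspace(0, 1, 3)
y = jnp.linspace(0, 1, 3)

# meshgrid replicates x and y into matching (3, 3) coordinate arrays,
# trading memory for a single fully vectorized evaluation
x_mesh, y_mesh = jnp.meshgrid(x, y)
z_mesh = f(x_mesh, y_mesh)
print(z_mesh.shape)  # → (3, 3)
```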
@@ -591,14 +615,14 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
```

```{code-cell} ipython3
- %%time
+ %%timeit
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's run again to get rid of compile time.

```{code-cell} ipython3
- %%time
+ %%timeit
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```
@@ -637,14 +661,14 @@ f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
With this construction, we can now call the function $f$ on flat (low memory) arrays.

```{code-cell} ipython3
- %%time
+ %%timeit
z_vmap = f_vec(x, y).block_until_ready()
```

We run it again to eliminate compile time.

```{code-cell} ipython3
- %%time
+ %%timeit
z_vmap = f_vec(x, y).block_until_ready()
```
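The nested `jax.vmap` construction shown in the hunk header can be sketched end to end as follows (the bivariate `f` is a hypothetical stand-in; the `in_axes` pattern matches the one in the diff):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # hypothetical bivariate function
    return jnp.cos(x**2 + y**2)

x = jnp.linspace(0, 1, 4)
y = jnp.linspace(0, 1, 4)

# vectorize over y first (x held fixed), then over x,
# so z_vmap[i, j] = f(x[i], y[j]) without building meshgrids
f_vec_y = jax.vmap(f, in_axes=(None, 0))
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))

z_vmap = f_vec(x, y)
print(z_vmap.shape)  # → (4, 4)
```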
@@ -716,14 +740,14 @@ def compute_call_price_jax(β=β,
Let's run it once to compile it:

```{code-cell} ipython3
- %%time
+ %%timeit
compute_call_price_jax().block_until_ready()
```

And now let's time it:

```{code-cell} ipython3
- %%time
+ %%timeit
compute_call_price_jax().block_until_ready()
```
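The lecture's `compute_call_price_jax` is defined outside this hunk; as a rough, self-contained sketch of the same pattern (a Monte Carlo estimate of a discounted call payoff, with made-up illustrative parameters rather than the lecture's model):

```python
import jax
import jax.numpy as jnp

def compute_call_price_sketch(β=0.96, K=100.0, μ=4.6, σ=0.1, M=1_000, seed=0):
    # hypothetical stand-in: average discounted payoff max(S - K, 0)
    # over M simulated lognormal terminal prices
    key = jax.random.PRNGKey(seed)
    S = jnp.exp(μ + σ * jax.random.normal(key, (M,)))
    return β * jnp.mean(jnp.maximum(S - K, 0.0))

price = float(compute_call_price_sketch().block_until_ready())
print(price >= 0.0)  # → True
```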