Skip to content

Commit d5f6640

Browse files
authored
Switch JAX experimental on IPU to new PjRt client. (#21)
The new IPU PjRt client supports: * Multiple IPUs, with collectives (supported by GCL); * Asynchronous dispatch; * Fallback to the host/CPU backend for slicing and low-FLOP ops.
1 parent 8aad5ab commit d5f6640

File tree

8 files changed

+99
-25
lines changed

8 files changed

+99
-25
lines changed

.github/workflows/jax-ci-ipu-internal.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
timeout-minutes: 10
2121
steps:
2222
- name: Cancel previous
23-
uses: styfle/cancel-workflow-action@0.10.0
23+
uses: styfle/cancel-workflow-action@0.11.0
2424
with:
2525
access_token: ${{ github.token }}
2626
if: ${{github.ref != 'refs/head/jax-v0.3.16-ipu'}}
@@ -67,9 +67,9 @@ jobs:
6767
python3 setup.py bdist_wheel --universal
6868
pip3 install dist/*.whl
6969
# Run IPU specific unit tests
70-
- name: Run JAX IPU unit tests
70+
- name: Run JAX IPU model unit tests
7171
run: |
72-
XLA_IPU_PLATFORM_DEVICE_COUNT=2 JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 pytest --tb=short -vv --log-cli-level=INFO ./tests/ipu/
72+
JAX_IPU_DEVICE_COUNT=2 JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 pytest --tb=short -vv --log-cli-level=INFO ./tests/ipu/
7373
# Dockerized workflow known to create issues with self-hosted servers.
7474
# Solution is to fully cleanup the workspace for the next action.
7575
# See: https://stackoverflow.com/questions/70483902/how-to-actually-clean-up-the-repository-on-self-hosted-runner-after-github-actio

.github/workflows/jax-ci-ipu-public.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
timeout-minutes: 10
2121
steps:
2222
- name: Cancel previous
23-
uses: styfle/cancel-workflow-action@0.10.0
23+
uses: styfle/cancel-workflow-action@0.11.0
2424
with:
2525
access_token: ${{ github.token }}
2626
if: ${{github.ref != 'refs/head/jax-v0.3.16-ipu'}}
@@ -68,7 +68,7 @@ jobs:
6868
# Run IPU specific unit tests
6969
- name: Run JAX IPU unit tests
7070
run: |
71-
XLA_IPU_PLATFORM_DEVICE_COUNT=2 JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 pytest --tb=short -vv --log-cli-level=INFO ./tests/ipu/
71+
JAX_IPU_DEVICE_COUNT=2 JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 pytest --tb=short -vv --log-cli-level=INFO ./tests/ipu/
7272
# Dockerized workflow known to create issues with self-hosted servers.
7373
# Solution is to fully cleanup the workspace for the next action.
7474
# See: https://stackoverflow.com/questions/70483902/how-to-actually-clean-up-the-repository-on-self-hosted-runner-after-github-actio

ipu/docs/build.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Note: the build process will work with more recent versions of NumPy, but that w
1313
Building `jaxlib` currently requires the branch `jax-v0.3.15-ipu`. Once the branch is checked out, the build process is similar to other backends:
1414
```bash
1515
export TF_POPLAR_BASE=#...poplar install directory
16-
python build/build.py --bazel_options=--override_repository=org_tensorflow=PATH/tensorflow-jax-experimental
16+
python build/build.py --enable_ipu --bazel_options=--override_repository=org_tensorflow=PATH/tensorflow-jax-experimental
1717
```
1818
The `override_repository` config is optional. By default, the build process will pull the experimental IPU TensorFlow XLA code from the repository https://github.com/graphcore-research/tensorflow-jax-experimental.
1919

jax/_src/lib/xla_bridge.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,19 @@ def ipu_flag_update_hook(name: str, value: Any):
129129
update_hook=ipu_model_tiles_update_hook)
130130
flags.DEFINE_bool(
131131
'jax_ipu_use_legacy_client',
132-
bool_env('JAX_IPU_LEGACY_CLIENT', True),
132+
bool_env('JAX_IPU_USE_LEGACY_CLIENT', False),
133133
'Use legacy IPU PjRt client, not supporting multiple IPUs.',
134134
update_hook=partial(ipu_flag_update_hook, 'jax_ipu_use_legacy_client'))
135+
flags.DEFINE_integer(
136+
'jax_ipu_device_count',
137+
int_env('JAX_IPU_DEVICE_COUNT', -1),
138+
'Number of IPUs attached to JAX.',
139+
update_hook=partial(ipu_flag_update_hook, 'jax_ipu_device_count'))
140+
flags.DEFINE_string(
141+
'jax_ipu_visible_devices',
142+
os.getenv('JAX_IPU_VISIBLE_DEVICES', '').lower(),
143+
'Specific IPUs visible and attached to JAX.',
144+
update_hook=partial(ipu_flag_update_hook, 'jax_ipu_visible_devices'))
135145

136146
def get_compile_options(
137147
num_replicas: int,

tests/ipu/basics_test.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from jax._src import test_util as jtu
1717
from unittest import SkipTest
1818
from functools import partial
19+
import time
1920

2021
import os
2122
import numpy as np
@@ -24,16 +25,24 @@
2425

2526
from jax import lax
2627
from jax.config import config
28+
from jaxlib.ipu_xla_client import IpuPjRtDevice
29+
30+
# Skipping tests on legacy IPU backend.
31+
is_ipu_legacy_backend = not isinstance(jax.devices("ipu")[0], IpuPjRtDevice)
2732

2833

2934
class IpuBasicsTest(jtu.JaxTestCase):
35+
def setUp(self):
36+
super().setUp()
37+
self.is_ipu_model = config.FLAGS.jax_ipu_use_model
3038

3139
def test_device_count(self):
32-
expected = os.getenv('XLA_IPU_PLATFORM_DEVICE_COUNT')
40+
expected = os.getenv('JAX_IPU_DEVICE_COUNT')
3341
expected = int(expected) if expected else 1
34-
35-
assert jax.device_count(backend='ipu') == expected
36-
assert len(jax.devices("ipu")) == expected
42+
# Only testing on IPU model for device count.
43+
if self.is_ipu_model:
44+
assert jax.device_count(backend='ipu') == expected
45+
assert len(jax.devices("ipu")) == expected
3746

3847
def test_default_backend(self):
3948
config.FLAGS.jax_platform_name = 'ipu'
@@ -67,6 +76,31 @@ def add(a,b):
6776
c = jit_add(a, b)
6877
self.assertAllClose(c, a + b)
6978

79+
def test_asynchronous_backend(self):
80+
@partial(jax.jit, backend="ipu")
81+
def fn(x):
82+
y = x * x
83+
return y * (y - x)
84+
85+
N = 1000000
86+
x = np.arange(N).astype(np.float32)
87+
# First dummy call to compile.
88+
fn(x)
89+
90+
num_iters = 10
91+
start = time.perf_counter()
92+
for _ in range(num_iters):
93+
x = fn(x)
94+
async_timing = time.perf_counter() - start
95+
x.block_until_ready()
96+
97+
start = time.perf_counter()
98+
for _ in range(num_iters):
99+
x = fn(x).block_until_ready()
100+
block_timing = time.perf_counter() - start
101+
102+
# At least 10x faster without blocking.
103+
self.assertLessEqual(async_timing * 10, block_timing)
70104

71105
def test_lax_argmin_argmax(self):
72106
@partial(jax.jit, backend="ipu")

tests/ipu/donate_argnums_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import jax
2222
import numpy as np
2323

24-
class DonateArgnumsTest(jtu.JaxTestCase):
24+
class IpuDonateArgnumsTest(jtu.JaxTestCase):
2525

2626
def testSingleDonateBufferFirstArgument(self):
2727

tests/ipu/infeed_outfeed_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import threading
17+
import unittest
1718

1819
from absl.testing import absltest
1920
from jax._src import test_util as jtu
@@ -25,11 +26,19 @@
2526
from jax.lib import xla_client
2627
import numpy as np
2728

29+
from jaxlib.ipu_xla_client import IpuPjRtDevice
30+
2831
config.parse_flags_with_absl()
2932
FLAGS = config.FLAGS
3033

31-
class InfeedTest(jtu.JaxTestCase):
3234

35+
# Skipping tests on new IPU backend.
36+
is_ipu_legacy_backend = not isinstance(jax.devices("ipu")[0], IpuPjRtDevice)
37+
38+
39+
class IpuInfeedTest(jtu.JaxTestCase):
40+
41+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
3342
def testInfeed(self):
3443

3544
@jax.jit
@@ -45,6 +54,7 @@ def f(x):
4554
device.transfer_to_infeed((y,))
4655
self.assertAllClose(f(x), x + y)
4756

57+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
4858
def testInfeedPytree(self):
4959

5060
x = np.float32(1.5)
@@ -64,6 +74,7 @@ def f(x):
6474
device.transfer_to_infeed(tuple(flat_to_infeed))
6575
self.assertAllClose(f(x), to_infeed)
6676

77+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
6778
def testOutfeed(self):
6879
hcb.stop_outfeed_receiver()
6980

@@ -86,6 +97,7 @@ def f(x):
8697
device.transfer_from_outfeed(
8798
xla_client.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
8899

100+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
89101
def testOutfeedPytree(self):
90102
hcb.stop_outfeed_receiver()
91103

@@ -109,6 +121,7 @@ def f(x, y):
109121
device.transfer_from_outfeed(
110122
xla_client.shape_from_pyval((x, y)).with_major_to_minor_layout_if_absent())
111123

124+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
112125
def testInfeedThenOutfeed(self):
113126
hcb.stop_outfeed_receiver()
114127

@@ -131,6 +144,7 @@ def f(x):
131144
execution.join()
132145
self.assertAllClose(out, y + np.float32(1))
133146

147+
@unittest.skipUnless(is_ipu_legacy_backend, "Infeed/outfeed not yet supported on IPU.")
134148
def testInfeedThenOutfeedInALoop(self):
135149
hcb.stop_outfeed_receiver()
136150

tests/ipu/multi_device_test.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import unittest
1516
from absl.testing import absltest
1617
from jax._src import test_util as jtu
17-
from unittest import SkipTest
18+
from functools import partial
1819

1920
import jax
2021
from jax import jit
2122
import jax.numpy as jnp
22-
from jax.config import config
23+
from jaxlib.ipu_xla_client import IpuTargetType
2324

2425
import numpy as np
2526

27+
ipu_num_devices = len(jax.devices("ipu"))
28+
is_ipu_model = len(jax.devices("ipu")) > 0 and jax.devices("ipu")[0].target_type == IpuTargetType.IPU_MODEL
2629

27-
# Set XLA_IPU_PLATFORM_DEVICE_COUNT=N os env to attach multiple IPUs
28-
class MultiDeviceTest(jtu.JaxTestCase):
29-
def test_jit_with_multi_devices(self):
30-
config.FLAGS.jax_platform_name = 'ipu'
31-
self.assertEqual(jax.default_backend(), 'ipu')
3230

33-
devices = jax.devices()
34-
if len(jax.devices()) < 2:
35-
raise SkipTest("IPU test requires multiple devices")
31+
# Set the environment variable JAX_IPU_DEVICE_COUNT=N to attach multiple IPUs.
32+
class IpuMultiDeviceTest(jtu.JaxTestCase):
33+
@unittest.skipIf(ipu_num_devices < 2, "Requires multiple IPU devices")
34+
def test_jit_on_multi_devices(self):
35+
self.assertEqual(jax.default_backend(), 'ipu')
36+
ipu_devices = jax.devices()
3637

3738
def func(x, w, b):
3839
return jnp.matmul(w, x) + b
@@ -41,11 +42,26 @@ def func(x, w, b):
4142
w = np.random.normal(size=[3, 2])
4243
b = np.random.normal(size=[3, 3])
4344

44-
for i in range(len(devices)):
45-
jit_func = jit(func, device=devices[i])
45+
# Just testing on 2 IPU devices, to reduce compilation time.
46+
# TODO: support portable executable?
47+
for i in range(2):
48+
jit_func = jit(func, device=ipu_devices[i])
4649
r = jit_func(x, w, b)
4750
self.assertAllClose(r, w @ x + b)
4851

52+
@unittest.skipIf(ipu_num_devices < 2 or is_ipu_model, "Requires multiple IPU hardware devices")
53+
def test_pmap_simple_reduce(self):
54+
N = 3
55+
data = np.arange(2 * N, dtype=np.float32).reshape((-1, N))
56+
57+
@partial(jax.pmap, axis_name='i', donate_argnums=(1,), backend="ipu")
58+
def parallel_fn(x, y):
59+
z = x + jax.lax.psum(y, 'i')
60+
return z
61+
62+
output = parallel_fn(data**2, data)
63+
self.assertAllClose(output, data**2 + np.sum(data, axis=0))
64+
4965

5066
if __name__ == "__main__":
5167
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)