Skip to content

Commit 311629f

Browse files
committed
Feature support testing for IAllreduce and IReduce for GPU backends
1 parent 9620a4b commit 311629f

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

test/mpi_support_test.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
include("common.jl")
2+
3+
MPI.Init()
4+
5+
# Those MPI calls may be unsupported features (e.g. for GPU backends), and will raise SIGSEGV
6+
# (or a similar signal) when called, which cannot be handled in Julia in a portable way.
7+
8+
op = ARGS[1]
9+
if op == "IAllreduce"
10+
# IAllreduce is unsupported for CUDA with OpenMPI + UCX
11+
# See https://docs.open-mpi.org/en/main/tuning-apps/networking/cuda.html#which-mpi-apis-do-not-work-with-cuda-aware-ucx
12+
send_arr = ArrayType(zeros(Int, 1))
13+
recv_arr = ArrayType{Int}(undef, 1)
14+
synchronize()
15+
req = MPI.IAllreduce!(send_arr, recv_arr, +, MPI.COMM_WORLD)
16+
MPI.Wait(req)
17+
18+
elseif op == "IReduce"
19+
# IAllreduce is unsupported for CUDA with OpenMPI + UCX
20+
send_arr = ArrayType(zeros(Int, 1))
21+
recv_arr = ArrayType{Int}(undef, 1)
22+
synchronize()
23+
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=0)
24+
MPI.Wait(req)
25+
26+
else
27+
error("unknown test: $op")
28+
end

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ if Sys.isunix()
7171
include("mpiexecjl.jl")
7272
end
7373

"""
    is_mpi_operation_supported(mpi_op, n=nprocs)

Return `true` when the MPI operation named `mpi_op` works with the current backend.

The probe runs `mpi_support_test.jl` in a separate `mpiexec` process with `n` ranks,
because an unsupported operation may die with a signal (e.g. SIGSEGV) that cannot be
caught in-process. A non-zero exit status is interpreted as "unsupported" and a
warning is emitted.
"""
function is_mpi_operation_supported(mpi_op, n=nprocs)
    script = joinpath(@__DIR__, "mpi_support_test.jl")
    probe = `$(mpiexec()) -n $n $(Base.julia_cmd()) --startup-file=no $script $mpi_op`
    # ignorestatus: a crashing probe must not abort the test suite itself.
    ok = success(run(ignorestatus(probe)))
    if !ok
        @warn "$mpi_op is unsupported with $backend_name"
    end
    return ok
end
81+
# Only GPU backends are expected to have unsupported features, so the (slow)
# external probes are skipped for plain CPU Arrays. The results are exported via
# ENV as "true"/"false" strings for the individual test files to pick up.
if ArrayType != Array
    for (var, op) in (("JULIA_MPI_TEST_IALLREDUCE", "IAllreduce"),
                      ("JULIA_MPI_TEST_IREDUCE", "IReduce"))
        ENV[var] = is_mpi_operation_supported(op)
    end
end
7487
excludefiles = split(get(ENV,"JULIA_MPI_TEST_EXCLUDE",""),',')
7588

7689
testdir = @__DIR__

test/test_allreduce.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ else
1313
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
1414
end
1515

16+
iallreduce_supported = get(ENV, "JULIA_MPI_TEST_IALLREDUCE", "true") == "true"
17+
18+
1619
for T = [Int]
1720
for dims = [1, 2, 3]
1821
send_arr = ArrayType(zeros(T, Tuple(3 for i in 1:dims)))
@@ -46,16 +49,20 @@ for T = [Int]
4649

4750
# Nonblocking
4851
recv_arr = ArrayType{T}(undef, size(send_arr))
49-
req = MPI.IAllreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
50-
MPI.Wait(req)
51-
@test recv_arr == comm_size .* send_arr
52+
if iallreduce_supported
53+
req = MPI.IAllreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
54+
MPI.Wait(req)
55+
end
56+
@test recv_arr == comm_size .* send_arr skip=!iallreduce_supported
5257

5358
# Nonblocking (IN_PLACE)
5459
recv_arr = copy(send_arr)
5560
synchronize()
56-
req = MPI.IAllreduce!(recv_arr, op, MPI.COMM_WORLD)
57-
MPI.Wait(req)
58-
@test recv_arr == comm_size .* send_arr
61+
if iallreduce_supported
62+
req = MPI.IAllreduce!(recv_arr, op, MPI.COMM_WORLD)
63+
MPI.Wait(req)
64+
end
65+
@test recv_arr == comm_size .* send_arr skip=!iallreduce_supported
5966
end
6067
end
6168
end

test/test_reduce.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ const can_do_closures =
99
Sys.ARCH !== :aarch64 &&
1010
!startswith(string(Sys.ARCH), "arm")
1111

12+
ireduce_supported = get(ENV, "JULIA_MPI_TEST_IREDUCE", "true") == "true"
13+
1214
using DoubleFloats
1315

1416
MPI.Init()
@@ -119,18 +121,22 @@ for T = [Int]
119121

120122
# Nonblocking
121123
recv_arr = ArrayType{T}(undef, size(send_arr))
122-
req = MPI.IReduce!(send_arr, recv_arr, op, MPI.COMM_WORLD; root=root)
123-
MPI.Wait(req)
124+
if ireduce_supported
125+
req = MPI.IReduce!(send_arr, recv_arr, op, MPI.COMM_WORLD; root=root)
126+
MPI.Wait(req)
127+
end
124128
if isroot
125-
@test recv_arr == sz .* send_arr
129+
@test recv_arr == sz .* send_arr skip=!ireduce_supported
126130
end
127131

128132
# Nonblocking (IN_PLACE)
129133
recv_arr = copy(send_arr)
130-
req = MPI.IReduce!(recv_arr, op, MPI.COMM_WORLD; root=root)
131-
MPI.Wait(req)
134+
if ireduce_supported
135+
req = MPI.IReduce!(recv_arr, op, MPI.COMM_WORLD; root=root)
136+
MPI.Wait(req)
137+
end
132138
if isroot
133-
@test recv_arr == sz .* send_arr
139+
@test recv_arr == sz .* send_arr skip=!ireduce_supported
134140
end
135141
end
136142
end
@@ -148,10 +154,12 @@ else
148154
end
149155

150156
recv_arr = isroot ? zeros(eltype(send_arr), size(send_arr)) : nothing
151-
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=root)
152-
MPI.Wait(req)
157+
if ireduce_supported
158+
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=root)
159+
MPI.Wait(req)
160+
end
153161
if rank == root
154-
@test recv_arr ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
162+
@test recv_arr ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64) skip=!ireduce_supported
155163
end
156164

157165
MPI.Barrier( MPI.COMM_WORLD )

0 commit comments

Comments
 (0)