Skip to content

Commit 4489a68

Browse files
committed
TEST: Parametrize all DerivativesDataSink tests
1 parent b297129 commit 4489a68

File tree

1 file changed

+94
-47
lines changed

1 file changed

+94
-47
lines changed

niworkflows/interfaces/tests/test_bids.py

Lines changed: 94 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,35 @@
4747
BOLD_PATH = "ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz"
4848

4949

50+
def make_prep_and_save(
51+
prep_interface,
52+
base_directory,
53+
out_path_base=None,
54+
**kwargs,
55+
):
56+
prep = save = prep_interface(
57+
**kwargs,
58+
**({"out_path_base": out_path_base} if prep_interface == bintfs.DerivativesDataSink else {}),
59+
)
60+
if prep_interface is bintfs.DerivativesDataSink:
61+
prep.inputs.base_directory = base_directory
62+
else:
63+
save = bintfs.SaveDerivative(base_directory=base_directory)
64+
65+
return prep, save
66+
67+
68+
def connect_and_run_save(prep_result, save):
69+
if prep_result.interface is bintfs.DerivativesDataSink:
70+
return prep_result
71+
72+
save.inputs.in_file = prep_result.outputs.out_file
73+
save.inputs.relative_path = prep_result.outputs.out_path
74+
save.inputs.metadata = prep_result.outputs.out_meta
75+
76+
return save.run()
77+
78+
5079
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
5180
@pytest.mark.parametrize("out_path_base", [None, "fmriprep"])
5281
@pytest.mark.parametrize(
@@ -296,36 +325,26 @@ def test_DerivativesDataSink_build_path(
296325
ds_inputs.append(str(fname))
297326

298327
base_directory = tmp_path / "output"
299-
work_dir = tmp_path / "work"
300328
base_directory.mkdir()
301-
work_dir.mkdir()
302329

303-
prep = save = interface(
330+
prep, save = make_prep_and_save(
331+
interface,
332+
base_directory=str(base_directory),
333+
out_path_base=out_path_base,
304334
in_file=ds_inputs,
305335
source_file=source,
306336
dismiss_entities=dismiss_entities,
307337
**entities,
308-
**({"out_path_base": out_path_base} if interface == bintfs.DerivativesDataSink else {}),
309338
)
310-
if interface == bintfs.DerivativesDataSink:
311-
prep.inputs.base_directory = str(base_directory)
312-
else:
313-
save = bintfs.SaveDerivative(base_directory=str(base_directory))
314-
315339
if isinstance(expectation, type):
316340
with pytest.raises(expectation):
317341
prep.run()
318342
return
319343

320-
prep_outputs = save_outputs = prep.run().outputs
321-
322-
if save is not prep:
323-
save.inputs.in_file = prep_outputs.out_file
324-
save.inputs.relative_path = prep_outputs.out_path
325-
save.inputs.metadata = prep_outputs.out_meta
326-
save_outputs = save.run().outputs
344+
prep_result = prep.run()
345+
save_result = connect_and_run_save(prep_result, save)
327346

328-
output = save_outputs.out_file
347+
output = save_result.outputs.out_file
329348
if isinstance(expectation, str):
330349
expectation = [expectation]
331350
output = [output]
@@ -363,7 +382,8 @@ def test_DerivativesDataSink_build_path(
363382
assert sha1(Path(out).read_bytes()).hexdigest() == chksum
364383

365384

366-
def test_DerivativesDataSink_dtseries_json(tmp_path):
385+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
386+
def test_DerivativesDataSink_dtseries_json(tmp_path, interface):
367387
cifti_fname = str(tmp_path / "test.dtseries.nii")
368388

369389
axes = (nb.cifti2.SeriesAxis(start=0, step=2, size=20),
@@ -378,20 +398,22 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
378398
source_file.parent.mkdir(parents=True)
379399
source_file.touch()
380400

381-
dds = bintfs.DerivativesDataSink(
382-
in_file=cifti_fname,
401+
prep, save = make_prep_and_save(
402+
interface,
383403
base_directory=str(tmp_path),
404+
out_path_base="",
405+
in_file=cifti_fname,
384406
source_file=str(source_file),
385407
compress=False,
386-
out_path_base="",
387408
space="fsLR",
388409
grayordinates="91k",
389410
RepetitionTime=2.0,
390411
)
391412

392-
res = dds.run()
413+
prep_result = prep.run()
414+
save_result = connect_and_run_save(prep_result, save)
393415

394-
out_path = Path(res.outputs.out_file)
416+
out_path = Path(save_result.outputs.out_file)
395417

396418
assert out_path.name == "sub-01_task-rest_space-fsLR_bold.dtseries.nii"
397419
old_sidecar = out_path.with_name("sub-01_task-rest_space-fsLR_bold.dtseries.json")
@@ -402,6 +424,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
402424
assert "RepetitionTime" in json.loads(new_sidecar.read_text())
403425

404426

427+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
405428
@pytest.mark.parametrize(
406429
"space, size, units, xcodes, zipped, fixed, data_dtype",
407430
[
@@ -427,7 +450,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
427450
],
428451
)
429452
def test_DerivativesDataSink_bold(
430-
tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype
453+
tmp_path, interface, space, size, units, xcodes, zipped, fixed, data_dtype
431454
):
432455
fname = str(tmp_path / "source.nii") + (".gz" if zipped else "")
433456

@@ -438,25 +461,29 @@ def test_DerivativesDataSink_bold(
438461
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)
439462

440463
# BOLD derivative in T1w space
441-
dds = bintfs.DerivativesDataSink(
464+
prep, _ = make_prep_and_save(
465+
interface,
442466
base_directory=str(tmp_path),
443467
keep_dtype=True,
444468
data_dtype=data_dtype or Undefined,
445469
desc="preproc",
446470
source_file=BOLD_PATH,
447471
space=space or Undefined,
448472
in_file=fname,
449-
).run()
473+
)
450474

451-
nii = nb.load(dds.outputs.out_file)
452-
assert dds.outputs.fixed_hdr == fixed
475+
prep_result = prep.run()
476+
477+
nii = nb.load(prep_result.outputs.out_file)
478+
assert prep_result.outputs.fixed_hdr == fixed
453479
if data_dtype:
454480
assert nii.get_data_dtype() == np.dtype(data_dtype)
455481
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
456482
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
457483
assert nii.header.get_xyzt_units() == ("mm", "sec")
458484

459485

486+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
460487
@pytest.mark.parametrize(
461488
"space, size, units, xcodes, fixed",
462489
[
@@ -480,7 +507,7 @@ def test_DerivativesDataSink_bold(
480507
(None, (30, 30, 30), (None, "sec"), (0, 0), [True]),
481508
],
482509
)
483-
def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
510+
def test_DerivativesDataSink_t1w(tmp_path, interface, space, size, units, xcodes, fixed):
484511
fname = str(tmp_path / "source.nii.gz")
485512

486513
hdr = nb.Nifti1Header()
@@ -490,22 +517,26 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
490517
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)
491518

492519
# BOLD derivative in T1w space
493-
dds = bintfs.DerivativesDataSink(
520+
prep, _ = make_prep_and_save(
521+
interface,
494522
base_directory=str(tmp_path),
495523
keep_dtype=True,
496524
desc="preproc",
497525
source_file=T1W_PATH,
498526
space=space or Undefined,
499527
in_file=fname,
500-
).run()
528+
)
529+
530+
prep_result = prep.run()
501531

502-
nii = nb.load(dds.outputs.out_file)
503-
assert dds.outputs.fixed_hdr == fixed
532+
nii = nb.load(prep_result.outputs.out_file)
533+
assert prep_result.outputs.fixed_hdr == fixed
504534
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
505535
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
506536
assert nii.header.get_xyzt_units() == ("mm", "unknown")
507537

508538

539+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
509540
@pytest.mark.parametrize(
510541
"source_file",
511542
[
@@ -517,7 +548,7 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
517548
@pytest.mark.parametrize("source_dtype", ["<i4", "<f4"])
518549
@pytest.mark.parametrize("in_dtype", ["<i4", "<f4"])
519550
def test_DerivativesDataSink_data_dtype_source(
520-
tmp_path, source_file, source_dtype, in_dtype
551+
tmp_path, interface, source_file, source_dtype, in_dtype
521552
):
522553

523554
def make_empty_nii_with_dtype(fname, dtype):
@@ -539,19 +570,23 @@ def make_empty_nii_with_dtype(fname, dtype):
539570
for s in source_file:
540571
make_empty_nii_with_dtype(s, source_dtype)
541572

542-
dds = bintfs.DerivativesDataSink(
573+
prep, save = make_prep_and_save(
574+
interface,
543575
base_directory=str(tmp_path),
544576
data_dtype="source",
545577
desc="preproc",
546578
source_file=source_file,
547579
in_file=in_file,
548-
).run()
580+
)
581+
582+
prep_result = prep.run()
549583

550-
nii = nb.load(dds.outputs.out_file)
584+
nii = nb.load(prep_result.outputs.out_file)
551585
assert nii.get_data_dtype() == np.dtype(source_dtype)
552586

553587

554-
def test_DerivativesDataSink_fmapid(tmp_path):
588+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
589+
def test_DerivativesDataSink_fmapid(tmp_path, interface):
555590
"""Ascertain #637 is not regressing."""
556591
source_file = [
557592
(tmp_path / s)
@@ -569,7 +604,8 @@ def test_DerivativesDataSink_fmapid(tmp_path):
569604
in_file = tmp_path / "report.svg"
570605
in_file.write_text("")
571606

572-
dds = bintfs.DerivativesDataSink(
607+
prep, save = make_prep_and_save(
608+
interface,
573609
base_directory=str(tmp_path),
574610
datatype="figures",
575611
suffix="fieldmap",
@@ -579,12 +615,17 @@ def test_DerivativesDataSink_fmapid(tmp_path):
579615
fmapid="auto00000",
580616
source_file=[str(s.absolute()) for s in source_file],
581617
in_file=str(in_file),
582-
).run()
583-
assert dds.outputs.out_file.endswith("sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg")
618+
)
619+
620+
prep_result = prep.run()
621+
save_result = connect_and_run_save(prep_result, save)
622+
623+
assert save_result.outputs.out_file.endswith("sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg")
584624

585625

626+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
586627
@pytest.mark.parametrize("dtype", ("i2", "u2", "f4"))
587-
def test_DerivativesDataSink_values(tmp_path, dtype):
628+
def test_DerivativesDataSink_values(tmp_path, interface, dtype):
588629
# We use static checksums above, which ensures we don't break things, but
589630
# pins the tests to specific values.
590631
# Here we use random values, check that the values are preserved, and then
@@ -599,16 +640,19 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
599640
orig_data = np.asanyarray(nb.load(fname).dataobj)
600641
expected = np.rint(orig_data) if dtype[0] in "iu" else orig_data
601642

602-
dds = bintfs.DerivativesDataSink(
643+
prep, _ = make_prep_and_save(
644+
interface,
603645
base_directory=str(tmp_path),
604646
keep_dtype=True,
605647
data_dtype=dtype,
606648
desc="preproc",
607649
source_file=T1W_PATH,
608650
in_file=fname,
609-
).run()
651+
)
652+
653+
prep_result = prep.run()
610654

611-
out_file = Path(dds.outputs.out_file)
655+
out_file = Path(prep_result.outputs.out_file)
612656

613657
nii = nb.load(out_file)
614658
assert np.allclose(nii.dataobj, expected)
@@ -617,14 +661,17 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
617661
out_file.unlink()
618662

619663
# Rerun to ensure determinism with non-zero data
620-
dds = bintfs.DerivativesDataSink(
664+
prep, _ = make_prep_and_save(
665+
interface,
621666
base_directory=str(tmp_path),
622667
keep_dtype=True,
623668
data_dtype=dtype,
624669
desc="preproc",
625670
source_file=T1W_PATH,
626671
in_file=fname,
627-
).run()
672+
)
673+
674+
prep_result = prep.run()
628675

629676
assert sha1(out_file.read_bytes()).hexdigest() == checksum
630677

0 commit comments

Comments
 (0)