Skip to content

Commit 8c14cbe

Browse files
committed
TEST: Parametrize all DerivativesDataSink tests
1 parent b297129 commit 8c14cbe

File tree

1 file changed

+95
-47
lines changed

1 file changed

+95
-47
lines changed

niworkflows/interfaces/tests/test_bids.py

Lines changed: 95 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,34 @@
4747
BOLD_PATH = "ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz"
4848

4949

50+
def make_prep_and_save(
51+
prep_interface,
52+
base_directory,
53+
out_path_base=None,
54+
**kwargs,
55+
):
56+
if prep_interface is bintfs.DerivativesDataSink:
57+
kwargs.update(out_path_base=out_path_base, base_directory=base_directory)
58+
59+
prep = save = prep_interface(**kwargs)
60+
61+
if prep_interface is not bintfs.DerivativesDataSink:
62+
save = bintfs.SaveDerivative(base_directory=base_directory)
63+
64+
return prep, save
65+
66+
67+
def connect_and_run_save(prep_result, save):
68+
if prep_result.interface is bintfs.DerivativesDataSink:
69+
return prep_result
70+
71+
save.inputs.in_file = prep_result.outputs.out_file
72+
save.inputs.relative_path = prep_result.outputs.out_path
73+
save.inputs.metadata = prep_result.outputs.out_meta
74+
75+
return save.run()
76+
77+
5078
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
5179
@pytest.mark.parametrize("out_path_base", [None, "fmriprep"])
5280
@pytest.mark.parametrize(
@@ -296,36 +324,26 @@ def test_DerivativesDataSink_build_path(
296324
ds_inputs.append(str(fname))
297325

298326
base_directory = tmp_path / "output"
299-
work_dir = tmp_path / "work"
300327
base_directory.mkdir()
301-
work_dir.mkdir()
302328

303-
prep = save = interface(
329+
prep, save = make_prep_and_save(
330+
interface,
331+
base_directory=str(base_directory),
332+
out_path_base=out_path_base,
304333
in_file=ds_inputs,
305334
source_file=source,
306335
dismiss_entities=dismiss_entities,
307336
**entities,
308-
**({"out_path_base": out_path_base} if interface == bintfs.DerivativesDataSink else {}),
309337
)
310-
if interface == bintfs.DerivativesDataSink:
311-
prep.inputs.base_directory = str(base_directory)
312-
else:
313-
save = bintfs.SaveDerivative(base_directory=str(base_directory))
314-
315338
if isinstance(expectation, type):
316339
with pytest.raises(expectation):
317340
prep.run()
318341
return
319342

320-
prep_outputs = save_outputs = prep.run().outputs
321-
322-
if save is not prep:
323-
save.inputs.in_file = prep_outputs.out_file
324-
save.inputs.relative_path = prep_outputs.out_path
325-
save.inputs.metadata = prep_outputs.out_meta
326-
save_outputs = save.run().outputs
343+
prep_result = prep.run()
344+
save_result = connect_and_run_save(prep_result, save)
327345

328-
output = save_outputs.out_file
346+
output = save_result.outputs.out_file
329347
if isinstance(expectation, str):
330348
expectation = [expectation]
331349
output = [output]
@@ -363,7 +381,8 @@ def test_DerivativesDataSink_build_path(
363381
assert sha1(Path(out).read_bytes()).hexdigest() == chksum
364382

365383

366-
def test_DerivativesDataSink_dtseries_json(tmp_path):
384+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
385+
def test_DerivativesDataSink_dtseries_json(tmp_path, interface):
367386
cifti_fname = str(tmp_path / "test.dtseries.nii")
368387

369388
axes = (nb.cifti2.SeriesAxis(start=0, step=2, size=20),
@@ -378,20 +397,22 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
378397
source_file.parent.mkdir(parents=True)
379398
source_file.touch()
380399

381-
dds = bintfs.DerivativesDataSink(
382-
in_file=cifti_fname,
400+
prep, save = make_prep_and_save(
401+
interface,
383402
base_directory=str(tmp_path),
403+
out_path_base="",
404+
in_file=cifti_fname,
384405
source_file=str(source_file),
385406
compress=False,
386-
out_path_base="",
387407
space="fsLR",
388408
grayordinates="91k",
389409
RepetitionTime=2.0,
390410
)
391411

392-
res = dds.run()
412+
prep_result = prep.run()
413+
save_result = connect_and_run_save(prep_result, save)
393414

394-
out_path = Path(res.outputs.out_file)
415+
out_path = Path(save_result.outputs.out_file)
395416

396417
assert out_path.name == "sub-01_task-rest_space-fsLR_bold.dtseries.nii"
397418
old_sidecar = out_path.with_name("sub-01_task-rest_space-fsLR_bold.dtseries.json")
@@ -402,6 +423,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
402423
assert "RepetitionTime" in json.loads(new_sidecar.read_text())
403424

404425

426+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
405427
@pytest.mark.parametrize(
406428
"space, size, units, xcodes, zipped, fixed, data_dtype",
407429
[
@@ -427,7 +449,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
427449
],
428450
)
429451
def test_DerivativesDataSink_bold(
430-
tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype
452+
tmp_path, interface, space, size, units, xcodes, zipped, fixed, data_dtype
431453
):
432454
fname = str(tmp_path / "source.nii") + (".gz" if zipped else "")
433455

@@ -438,25 +460,29 @@ def test_DerivativesDataSink_bold(
438460
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)
439461

440462
# BOLD derivative in T1w space
441-
dds = bintfs.DerivativesDataSink(
463+
prep, _ = make_prep_and_save(
464+
interface,
442465
base_directory=str(tmp_path),
443466
keep_dtype=True,
444467
data_dtype=data_dtype or Undefined,
445468
desc="preproc",
446469
source_file=BOLD_PATH,
447470
space=space or Undefined,
448471
in_file=fname,
449-
).run()
472+
)
473+
474+
prep_result = prep.run()
450475

451-
nii = nb.load(dds.outputs.out_file)
452-
assert dds.outputs.fixed_hdr == fixed
476+
nii = nb.load(prep_result.outputs.out_file)
477+
assert prep_result.outputs.fixed_hdr == fixed
453478
if data_dtype:
454479
assert nii.get_data_dtype() == np.dtype(data_dtype)
455480
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
456481
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
457482
assert nii.header.get_xyzt_units() == ("mm", "sec")
458483

459484

485+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
460486
@pytest.mark.parametrize(
461487
"space, size, units, xcodes, fixed",
462488
[
@@ -480,7 +506,7 @@ def test_DerivativesDataSink_bold(
480506
(None, (30, 30, 30), (None, "sec"), (0, 0), [True]),
481507
],
482508
)
483-
def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
509+
def test_DerivativesDataSink_t1w(tmp_path, interface, space, size, units, xcodes, fixed):
484510
fname = str(tmp_path / "source.nii.gz")
485511

486512
hdr = nb.Nifti1Header()
@@ -490,22 +516,26 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
490516
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)
491517

492518
# BOLD derivative in T1w space
493-
dds = bintfs.DerivativesDataSink(
519+
prep, _ = make_prep_and_save(
520+
interface,
494521
base_directory=str(tmp_path),
495522
keep_dtype=True,
496523
desc="preproc",
497524
source_file=T1W_PATH,
498525
space=space or Undefined,
499526
in_file=fname,
500-
).run()
527+
)
501528

502-
nii = nb.load(dds.outputs.out_file)
503-
assert dds.outputs.fixed_hdr == fixed
529+
prep_result = prep.run()
530+
531+
nii = nb.load(prep_result.outputs.out_file)
532+
assert prep_result.outputs.fixed_hdr == fixed
504533
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
505534
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
506535
assert nii.header.get_xyzt_units() == ("mm", "unknown")
507536

508537

538+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
509539
@pytest.mark.parametrize(
510540
"source_file",
511541
[
@@ -517,7 +547,7 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
517547
@pytest.mark.parametrize("source_dtype", ["<i4", "<f4"])
518548
@pytest.mark.parametrize("in_dtype", ["<i4", "<f4"])
519549
def test_DerivativesDataSink_data_dtype_source(
520-
tmp_path, source_file, source_dtype, in_dtype
550+
tmp_path, interface, source_file, source_dtype, in_dtype
521551
):
522552

523553
def make_empty_nii_with_dtype(fname, dtype):
@@ -539,19 +569,23 @@ def make_empty_nii_with_dtype(fname, dtype):
539569
for s in source_file:
540570
make_empty_nii_with_dtype(s, source_dtype)
541571

542-
dds = bintfs.DerivativesDataSink(
572+
prep, save = make_prep_and_save(
573+
interface,
543574
base_directory=str(tmp_path),
544575
data_dtype="source",
545576
desc="preproc",
546577
source_file=source_file,
547578
in_file=in_file,
548-
).run()
579+
)
580+
581+
prep_result = prep.run()
549582

550-
nii = nb.load(dds.outputs.out_file)
583+
nii = nb.load(prep_result.outputs.out_file)
551584
assert nii.get_data_dtype() == np.dtype(source_dtype)
552585

553586

554-
def test_DerivativesDataSink_fmapid(tmp_path):
587+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
588+
def test_DerivativesDataSink_fmapid(tmp_path, interface):
555589
"""Ascertain #637 is not regressing."""
556590
source_file = [
557591
(tmp_path / s)
@@ -569,7 +603,8 @@ def test_DerivativesDataSink_fmapid(tmp_path):
569603
in_file = tmp_path / "report.svg"
570604
in_file.write_text("")
571605

572-
dds = bintfs.DerivativesDataSink(
606+
prep, save = make_prep_and_save(
607+
interface,
573608
base_directory=str(tmp_path),
574609
datatype="figures",
575610
suffix="fieldmap",
@@ -579,12 +614,19 @@ def test_DerivativesDataSink_fmapid(tmp_path):
579614
fmapid="auto00000",
580615
source_file=[str(s.absolute()) for s in source_file],
581616
in_file=str(in_file),
582-
).run()
583-
assert dds.outputs.out_file.endswith("sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg")
617+
)
584618

619+
prep_result = prep.run()
620+
save_result = connect_and_run_save(prep_result, save)
585621

622+
assert save_result.outputs.out_file.endswith(
623+
"sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg"
624+
)
625+
626+
627+
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
586628
@pytest.mark.parametrize("dtype", ("i2", "u2", "f4"))
587-
def test_DerivativesDataSink_values(tmp_path, dtype):
629+
def test_DerivativesDataSink_values(tmp_path, interface, dtype):
588630
# We use static checksums above, which ensures we don't break things, but
589631
# pins the tests to specific values.
590632
# Here we use random values, check that the values are preserved, and then
@@ -599,16 +641,19 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
599641
orig_data = np.asanyarray(nb.load(fname).dataobj)
600642
expected = np.rint(orig_data) if dtype[0] in "iu" else orig_data
601643

602-
dds = bintfs.DerivativesDataSink(
644+
prep, _ = make_prep_and_save(
645+
interface,
603646
base_directory=str(tmp_path),
604647
keep_dtype=True,
605648
data_dtype=dtype,
606649
desc="preproc",
607650
source_file=T1W_PATH,
608651
in_file=fname,
609-
).run()
652+
)
653+
654+
prep_result = prep.run()
610655

611-
out_file = Path(dds.outputs.out_file)
656+
out_file = Path(prep_result.outputs.out_file)
612657

613658
nii = nb.load(out_file)
614659
assert np.allclose(nii.dataobj, expected)
@@ -617,14 +662,17 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
617662
out_file.unlink()
618663

619664
# Rerun to ensure determinism with non-zero data
620-
dds = bintfs.DerivativesDataSink(
665+
prep, _ = make_prep_and_save(
666+
interface,
621667
base_directory=str(tmp_path),
622668
keep_dtype=True,
623669
data_dtype=dtype,
624670
desc="preproc",
625671
source_file=T1W_PATH,
626672
in_file=fname,
627-
).run()
673+
)
674+
675+
prep_result = prep.run()
628676

629677
assert sha1(out_file.read_bytes()).hexdigest() == checksum
630678

0 commit comments

Comments
 (0)