diff --git a/niworkflows/interfaces/bids.py b/niworkflows/interfaces/bids.py index 21a4e43f4a0..6a50417d400 100644 --- a/niworkflows/interfaces/bids.py +++ b/niworkflows/interfaces/bids.py @@ -28,6 +28,7 @@ import shutil import os import re +import sys import nibabel as nb import numpy as np @@ -48,6 +49,7 @@ SimpleInterface, ) from nipype.interfaces.io import add_traits +from nipype.utils.filemanip import hash_infile import templateflow as tf from .. import data from ..utils.bids import _init_layout, relative_to_root @@ -63,6 +65,15 @@ LOGGER = logging.getLogger("nipype.interface") +if sys.version_info < (3, 10): # PY39 + builtin_zip = zip + + def zip(*args, strict=False): + if strict and any(len(args[0]) != len(arg) for arg in args): + raise ValueError("strict_zip() requires all arguments to have the same length") + return builtin_zip(*args) + + def _none(): return None @@ -288,6 +299,498 @@ def _run_interface(self, runtime): return runtime +class _PrepareDerivativeInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec): + check_hdr = traits.Bool(True, usedefault=True, desc="fix headers of NIfTI outputs") + compress = InputMultiObject( + traits.Either(None, traits.Bool), + usedefault=True, + desc="whether ``in_file`` should be compressed (True), uncompressed (False) " + "or left unmodified (None, default).", + ) + data_dtype = Str( + desc="NumPy datatype to coerce NIfTI data to, or `source` to match the input file dtype" + ) + dismiss_entities = InputMultiObject( + traits.Either(None, Str), + usedefault=True, + desc="a list entities that will not be propagated from the source file", + ) + in_file = InputMultiObject( + File(exists=True), mandatory=True, desc="the object to be saved" + ) + meta_dict = traits.DictStrAny(desc="an input dictionary containing metadata") + source_file = InputMultiObject( + File(exists=False), mandatory=True, desc="the source file(s) to extract entities from") + + +class _PrepareDerivativeOutputSpec(TraitedSpec): + out_file = OutputMultiObject(File(exists=True), desc="derivative file path") + out_meta = traits.DictStrAny(desc="derivative metadata") + out_path = OutputMultiObject(Str, desc="relative path in target directory") + fixed_hdr = traits.List(traits.Bool, desc="whether derivative header was fixed") + + +class PrepareDerivative(SimpleInterface): + """ + Prepare derivative files and metadata. + + Collects entities from source files and inputs, filters them for allowed entities, + and constructs a relative path within a BIDS dataset. + + For each file, the interface will determine if any changes to the file contents + are needed, including: + + - Compression (or decompression) of the file + - Coercion of the data type + - Fixing the NIfTI header + - Align qform and sform affines and codes + - Set zooms and units + + If the input file needs to be modified, the interface will write a new file + and return the path to it. If no changes are needed, the interface will return + the path to the input file. + + .. testsetup:: + + >>> data_dir_canary() + + >>> import tempfile + >>> tmpdir = Path(tempfile.mkdtemp()) + >>> tmpfile = tmpdir / 'a_temp_file.nii.gz' + >>> tmpfile.open('w').close() # "touch" the file + >>> t1w_source = bids_collect_data( + ... str(datadir / 'ds114'), '01', bids_validate=False)[0]['t1w'][0] + >>> prep = PrepareDerivative(check_hdr=False) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = t1w_source + >>> prep.inputs.desc = 'denoised' + >>> prep.inputs.compress = False + >>> res = prep.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS + '.../a_temp_file.nii.gz' + >>> res.outputs.out_path # doctest: +ELLIPSIS + 'sub-01/ses-retest/anat/sub-01_ses-retest_desc-denoised_T1w.nii' + + >>> tmpfile = tmpdir / 'a_temp_file.nii' + >>> tmpfile.open('w').close() # "touch" the file + >>> prep = PrepareDerivative(check_hdr=False, allowed_entities=("custom",)) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = t1w_source + >>> prep.inputs.custom = 'noise' + >>> res = prep.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS + '.../a_temp_file.nii' + >>> res.outputs.out_path # doctest: +ELLIPSIS + 'sub-01/ses-retest/anat/sub-01_ses-retest_custom-noise_T1w.nii' + + >>> prep = PrepareDerivative(check_hdr=False, allowed_entities=("custom",)) + >>> prep.inputs.in_file = [str(tmpfile), str(tmpfile)] + >>> prep.inputs.source_file = t1w_source + >>> prep.inputs.custom = [1, 2] + >>> prep.inputs.compress = True + >>> res = prep.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['.../a_temp_file.nii', '.../a_temp_file.nii'] + >>> res.outputs.out_path # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['sub-01/ses-retest/anat/sub-01_ses-retest_custom-1_T1w.nii.gz', + 'sub-01/ses-retest/anat/sub-01_ses-retest_custom-2_T1w.nii.gz'] + + >>> prep = PrepareDerivative(check_hdr=False, allowed_entities=("custom1", "custom2")) + >>> prep.inputs.in_file = [str(tmpfile)] * 2 + >>> prep.inputs.source_file = t1w_source + >>> prep.inputs.custom1 = [1, 2] + >>> prep.inputs.custom2 = "b" + >>> res = prep.run() + >>> res.outputs.out_file # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['.../a_temp_file.nii', '.../a_temp_file.nii'] + >>> res.outputs.out_path # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['sub-01/ses-retest/anat/sub-01_ses-retest_custom1-1_custom2-b_T1w.nii', + 'sub-01/ses-retest/anat/sub-01_ses-retest_custom1-2_custom2-b_T1w.nii'] + + When multiple source files are passed, only common entities are passed down. + For example, if two T1w images from different sessions are used to generate + a single image, the session entity is removed automatically. + + >>> bids_dir = tmpdir / 'bidsroot' + >>> multi_source = [ + ... bids_dir / 'sub-02/ses-A/anat/sub-02_ses-A_T1w.nii.gz', + ... bids_dir / 'sub-02/ses-B/anat/sub-02_ses-B_T1w.nii.gz'] + >>> for source_file in multi_source: + ... source_file.parent.mkdir(parents=True, exist_ok=True) + ... _ = source_file.write_text("") + >>> prep = PrepareDerivative(check_hdr=False) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = list(map(str, multi_source)) + >>> prep.inputs.desc = 'preproc' + >>> res = prep.run() + >>> res.outputs.out_path # doctest: +ELLIPSIS + 'sub-02/anat/sub-02_desc-preproc_T1w.nii' + + If, on the other hand, only one is used, the session is preserved: + + >>> prep.inputs.source_file = str(multi_source[0]) + >>> res = prep.run() + >>> res.outputs.out_path # doctest: +ELLIPSIS + 'sub-02/ses-A/anat/sub-02_ses-A_desc-preproc_T1w.nii' + + >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' + >>> bids_dir.mkdir(parents=True, exist_ok=True) + >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' + >>> tricky_source.open('w').close() + >>> prep = PrepareDerivative(check_hdr=False) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = str(tricky_source) + >>> prep.inputs.desc = 'preproc' + >>> res = prep.run() + >>> res.outputs.out_path # doctest: +ELLIPSIS + 'sub-02/ses-noanat/func/sub-02_ses-noanat_task-rest_run-01_desc-preproc_bold.nii' + + >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' + >>> bids_dir.mkdir(parents=True, exist_ok=True) + >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' + >>> tricky_source.open('w').close() + >>> prep = PrepareDerivative(check_hdr=False) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = str(tricky_source) + >>> prep.inputs.desc = 'preproc' + >>> prep.inputs.RepetitionTime = 0.75 + >>> res = prep.run() + >>> res.outputs.out_meta # doctest: +ELLIPSIS + {'RepetitionTime': 0.75} + + >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' + >>> bids_dir.mkdir(parents=True, exist_ok=True) + >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' + >>> tricky_source.open('w').close() + >>> prep = PrepareDerivative(check_hdr=False, SkullStripped=True) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = str(tricky_source) + >>> prep.inputs.desc = 'preproc' + >>> prep.inputs.space = 'MNI152NLin6Asym' + >>> prep.inputs.resolution = '01' + >>> prep.inputs.RepetitionTime = 0.75 + >>> res = prep.run() + >>> res.outputs.out_meta # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + {'SkullStripped': True, + 'RepetitionTime': 0.75, + 'Resolution': 'Template MNI152NLin6Asym (1.0x1.0x1.0 mm^3)...'} + + >>> bids_dir = tmpdir / 'bidsroot' / 'sub-02' / 'ses-noanat' / 'func' + >>> bids_dir.mkdir(parents=True, exist_ok=True) + >>> tricky_source = bids_dir / 'sub-02_ses-noanat_task-rest_run-01_bold.nii.gz' + >>> tricky_source.open('w').close() + >>> prep = PrepareDerivative(check_hdr=False, SkullStripped=True) + >>> prep.inputs.in_file = str(tmpfile) + >>> prep.inputs.source_file = str(tricky_source) + >>> prep.inputs.desc = 'preproc' + >>> prep.inputs.resolution = 'native' + >>> prep.inputs.space = 'MNI152NLin6Asym' + >>> prep.inputs.RepetitionTime = 0.75 + >>> prep.inputs.meta_dict = {'RepetitionTime': 1.75, 'SkullStripped': False, 'Z': 'val'} + >>> res = prep.run() + >>> res.outputs.out_meta # doctest: +ELLIPSIS + {'RepetitionTime': 0.75, 'SkullStripped': True, 'Z': 'val'} + + """ + + input_spec = _PrepareDerivativeInputSpec + output_spec = _PrepareDerivativeOutputSpec + _config_entities = frozenset({e["name"] for e in BIDS_DERIV_ENTITIES}) + _config_entities_dict = BIDS_DERIV_ENTITIES + _standard_spaces = STANDARD_SPACES + _file_patterns = BIDS_DERIV_PATTERNS + _default_dtypes = DEFAULT_DTYPES + + def __init__(self, allowed_entities=None, **inputs): + """Initialize the SimpleInterface and extend inputs with custom entities.""" + self._allowed_entities = set(allowed_entities or []).union( + set(self._config_entities) + ) + + self._metadata = {} + self._static_traits = self.input_spec.class_editable_traits() + sorted( + self._allowed_entities + ) + for dynamic_input in set(inputs) - set(self._static_traits): + self._metadata[dynamic_input] = inputs.pop(dynamic_input) + + # First regular initialization (constructs InputSpec object) + super().__init__(**inputs) + add_traits(self.inputs, self._allowed_entities) + for k in self._allowed_entities.intersection(list(inputs.keys())): + # Add additional input fields (self.inputs is an object) + setattr(self.inputs, k, inputs[k]) + + def _run_interface(self, runtime): + from bids.layout import parse_file_entities, Config + from bids.layout.writing import build_path + from bids.utils import listify + + # Metadata applies to all files, and is not subject to change + metadata = { + # Lowest precedence: metadata provided as a dictionary + **(self.inputs.meta_dict or {}), + # Middle precedence: metadata passed to constructor + **self._metadata, + # Highest precedence: metadata set as inputs + **({ + k: getattr(self.inputs, k) + for k in self.inputs.copyable_trait_names() + if k not in self._static_traits + }) + } + + in_file = listify(self.inputs.in_file) + + # Initialize entities with those from the source file. + custom_config = Config( + name="custom", + entities=self._config_entities_dict, + default_path_patterns=self._file_patterns, + ) + in_entities = [ + parse_file_entities( + str(relative_to_root(source_file)), + config=["bids", "derivatives", custom_config], + ) + for source_file in self.inputs.source_file + ] + out_entities = {k: v for k, v in in_entities[0].items() + if all(ent.get(k) == v for ent in in_entities[1:])} + for drop_entity in listify(self.inputs.dismiss_entities or []): + out_entities.pop(drop_entity, None) + + # Override extension with that of the input file(s) + out_entities["extension"] = [ + # _splitext does not accept .surf.gii (for instance) + "".join(Path(orig_file).suffixes).lstrip(".") + for orig_file in in_file + ] + + compress = listify(self.inputs.compress) or [None] + if len(compress) == 1: + compress = compress * len(in_file) + for i, ext in enumerate(out_entities["extension"]): + if compress[i] is not None: + ext = regz.sub("", ext) + out_entities["extension"][i] = f"{ext}.gz" if compress[i] else ext + + # Override entities with those set as inputs + for key in self._allowed_entities: + value = getattr(self.inputs, key) + if value is not None and isdefined(value): + out_entities[key] = value + + # Clean up native resolution with space + if out_entities.get("resolution") == "native" and out_entities.get("space"): + out_entities.pop("resolution", None) + + # Expand templateflow resolutions + resolution = out_entities.get("resolution") + space = out_entities.get("space") + if resolution: + # Standard spaces + if space in self._standard_spaces: + res = _get_tf_resolution(space, resolution) + else: # TODO: Nonstandard? + res = "Unknown" + metadata['Resolution'] = res + + if len(set(out_entities["extension"])) == 1: + out_entities["extension"] = out_entities["extension"][0] + + # Insert custom (non-BIDS) entities from allowed_entities. + custom_entities = set(out_entities) - set(self._config_entities) + patterns = self._file_patterns + if custom_entities: + # Example: f"{key}-{{{key}}}" -> "task-{task}" + custom_pat = "_".join(f"{key}-{{{key}}}" for key in sorted(custom_entities)) + patterns = [ + pat.replace("_{suffix", "_".join(("", custom_pat, "{suffix"))) + for pat in patterns + ] + + # Build the output path(s) + dest_files = build_path(out_entities, path_patterns=patterns) + if not dest_files: + raise ValueError(f"Could not build path with entities {out_entities}.") + + # Make sure the interpolated values is embedded in a list, and check + dest_files = listify(dest_files) + if len(in_file) != len(dest_files): + raise ValueError( + f"Input files ({len(in_file)}) not matched " + f"by interpolated patterns ({len(dest_files)})." + ) + + # Prepare SimpleInterface outputs object + self._results["out_file"] = [] + self._results["fixed_hdr"] = [False] * len(in_file) + self._results["out_path"] = dest_files + self._results["out_meta"] = metadata + + for i, (orig_file, dest_file) in enumerate(zip(in_file, dest_files)): + # Set data and header iff changes need to be made. If these are + # still None when it's time to write, just copy. + new_data, new_header = None, None + + is_nifti = False + with suppress(nb.filebasedimages.ImageFileError): + is_nifti = isinstance(nb.load(orig_file), nb.Nifti1Image) + + new_compression = False + if is_nifti: + new_compression = ( + os.fspath(orig_file).endswith(".gz") ^ os.fspath(dest_file).endswith(".gz") + ) + + data_dtype = self.inputs.data_dtype or self._default_dtypes[self.inputs.suffix] + if is_nifti and any((self.inputs.check_hdr, data_dtype)): + nii = nb.load(orig_file) + + if self.inputs.check_hdr: + hdr = nii.header + curr_units = tuple( + [None if u == "unknown" else u for u in hdr.get_xyzt_units()] + ) + curr_codes = (int(hdr["qform_code"]), int(hdr["sform_code"])) + + # Default to mm, use sec if data type is bold + units = ( + curr_units[0] or "mm", + "sec" if out_entities["suffix"] == "bold" else None, + ) + xcodes = (1, 1) # Derivative in its original scanner space + if self.inputs.space: + xcodes = ( + (4, 4) if self.inputs.space in self._standard_spaces else (2, 2) + ) + + curr_zooms = zooms = hdr.get_zooms() + if "RepetitionTime" in self.inputs.get(): + zooms = curr_zooms[:3] + (self.inputs.RepetitionTime,) + + if (curr_codes, curr_units, curr_zooms) != (xcodes, units, zooms): + self._results["fixed_hdr"][i] = True + new_header = hdr.copy() + new_header.set_qform(nii.affine, xcodes[0]) + new_header.set_sform(nii.affine, xcodes[1]) + new_header.set_xyzt_units(*units) + new_header.set_zooms(zooms) + + if data_dtype == "source": # match source dtype + try: + data_dtype = nb.load(self.inputs.source_file[0]).get_data_dtype() + except Exception: + LOGGER.warning( + f"Could not get data type of file {self.inputs.source_file[0]}" + ) + data_dtype = None + + if data_dtype: + data_dtype = np.dtype(data_dtype) + orig_dtype = nii.get_data_dtype() + if orig_dtype != data_dtype: + LOGGER.warning( + f"Changing {Path(dest_file).name} dtype " + f"from {orig_dtype} to {data_dtype}" + ) + # coerce dataobj to new data dtype + if np.issubdtype(data_dtype, np.integer): + new_data = np.rint(nii.dataobj).astype(data_dtype) + else: + new_data = np.asanyarray(nii.dataobj, dtype=data_dtype) + # and set header to match + if new_header is None: + new_header = nii.header.copy() + new_header.set_data_dtype(data_dtype) + del nii + + if new_data is new_header is None and not new_compression: + out_file = orig_file + else: + out_file = Path(runtime.cwd) / Path(dest_file).name + + orig_img = nb.load(orig_file) + + if new_header is None: + new_header = orig_img.header.copy() + + if new_data is None: + set_consumables(new_header, orig_img.dataobj) + new_data = orig_img.dataobj.get_unscaled() + else: + # Without this, we would be writing nans + # This is our punishment for hacking around nibabel defaults + new_header.set_slope_inter(slope=1., inter=0.) + unsafe_write_nifti_header_and_data( + fname=out_file, + header=new_header, + data=new_data + ) + del orig_img + + self._results["out_file"].append(str(out_file)) + + return runtime + + +class _SaveDerivativeInputSpec(TraitedSpec): + base_directory = Directory( + exists=True, mandatory=True, desc="Path to the base directory for storing data." + ) + in_file = InputMultiObject( + File(exists=True), mandatory=True, desc="the object to be saved" + ) + metadata = traits.DictStrAny(desc="metadata to be saved alongside the file") + relative_path = InputMultiObject( + traits.Str, desc="path to the file relative to the base directory" + ) + + +class _SaveDerivativeOutputSpec(TraitedSpec): + out_file = OutputMultiObject(File, desc="written file path") + out_meta = OutputMultiObject(File, desc="written JSON sidecar path") + + +class SaveDerivative(SimpleInterface): + """Save a prepared derivative file. + + This interface is intended to be used after the PrepareDerivative interface. + Its main purpose is to copy data to the output directory if an identical copy + is not already present. + + This ensures that changes to the output directory metadata (e.g., mtime) do not + trigger unnecessary recomputations in the workflow. + """ + input_spec = _SaveDerivativeInputSpec + output_spec = _SaveDerivativeOutputSpec + _always_run = True + + def _run_interface(self, runtime): + self._results["out_file"] = [] + self._results["out_meta"] = [] + + for in_file, relative_path in zip( + self.inputs.in_file, self.inputs.relative_path, + strict=True, + ): + out_file = Path(self.inputs.base_directory) / relative_path + out_file.parent.mkdir(exist_ok=True, parents=True) + + if not out_file.exists() or hash_infile(in_file) != hash_infile(out_file): + _copy_any(in_file, out_file) + + if self.inputs.metadata: + sidecar = out_file.parent / f"{out_file.name.split('.', 1)[0]}.json" + sidecar.unlink(missing_ok=True) + sidecar.write_text(dumps(self.inputs.metadata, indent=2)) + self._results["out_meta"].append(str(sidecar)) + self._results["out_file"].append(str(out_file)) + + return runtime + + class _DerivativesDataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec): base_directory = traits.Directory( desc="Path to the base directory for storing data." diff --git a/niworkflows/interfaces/tests/test_bids.py b/niworkflows/interfaces/tests/test_bids.py index 365a62356a8..33c30df18fd 100644 --- a/niworkflows/interfaces/tests/test_bids.py +++ b/niworkflows/interfaces/tests/test_bids.py @@ -47,6 +47,35 @@ BOLD_PATH = "ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz" +def make_prep_and_save( + prep_interface, + base_directory, + out_path_base=None, + **kwargs, +): + if prep_interface is bintfs.DerivativesDataSink: + kwargs.update(out_path_base=out_path_base, base_directory=base_directory) + + prep = save = prep_interface(**kwargs) + + if prep_interface is not bintfs.DerivativesDataSink: + save = bintfs.SaveDerivative(base_directory=base_directory) + + return prep, save + + +def connect_and_run_save(prep_result, save): + if prep_result.interface is bintfs.DerivativesDataSink: + return prep_result + + save.inputs.in_file = prep_result.outputs.out_file + save.inputs.relative_path = prep_result.outputs.out_path + save.inputs.metadata = prep_result.outputs.out_meta + + return save.run() + + +@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative]) @pytest.mark.parametrize("out_path_base", [None, "fmriprep"]) @pytest.mark.parametrize( "source,input_files,entities,expectation,checksum", @@ -258,6 +287,7 @@ @pytest.mark.parametrize("dismiss_entities", [None, ("run", "session")]) def test_DerivativesDataSink_build_path( tmp_path, + interface, out_path_base, source, input_files, @@ -267,6 +297,8 @@ def test_DerivativesDataSink_build_path( dismiss_entities, ): """Check a few common derivatives generated by NiPreps.""" + if interface is bintfs.PrepareDerivative and out_path_base is not None: + pytest.skip("PrepareDerivative does not support out_path_base") ds_inputs = [] for input_file in input_files: fname = tmp_path / input_file @@ -291,24 +323,31 @@ def test_DerivativesDataSink_build_path( ds_inputs.append(str(fname)) - dds = bintfs.DerivativesDataSink( + base_directory = tmp_path / "output" + base_directory.mkdir() + + prep, save = make_prep_and_save( + interface, + base_directory=str(base_directory), + out_path_base=out_path_base, in_file=ds_inputs, - base_directory=str(tmp_path), source_file=source, - out_path_base=out_path_base, dismiss_entities=dismiss_entities, **entities, ) - if isinstance(expectation, type): with pytest.raises(expectation): - dds.run() + prep.run() return - output = dds.run().outputs.out_file + prep_result = prep.run() + save_result = connect_and_run_save(prep_result, save) + + output = save_result.outputs.out_file if isinstance(expectation, str): expectation = [expectation] output = [output] + checksum = [checksum] if dismiss_entities: if "run" in dismiss_entities: @@ -320,26 +359,12 @@ def test_DerivativesDataSink_build_path( for e in expectation ] - base = out_path_base or "niworkflows" + base = (out_path_base or "niworkflows") if interface == bintfs.DerivativesDataSink else "" for out, exp in zip(output, expectation): - assert Path(out).relative_to(tmp_path) == Path(base) / exp - - os.chdir(str(tmp_path)) # Exercise without setting base_directory - dds = bintfs.DerivativesDataSink( - in_file=ds_inputs, - dismiss_entities=dismiss_entities, - source_file=source, - out_path_base=out_path_base, - **entities, - ) - - output = dds.run().outputs.out_file - if isinstance(output, str): - output = [output] - checksum = [checksum] + assert Path(out).relative_to(base_directory) == Path(base) / exp for out, exp in zip(output, expectation): - assert Path(out).relative_to(tmp_path) == Path(base) / exp + assert Path(out).relative_to(base_directory) == Path(base) / exp # Regression - some images were given nan scale factors if out.endswith(".nii") or out.endswith(".nii.gz"): img = nb.load(out) @@ -356,7 +381,8 @@ def test_DerivativesDataSink_build_path( assert sha1(Path(out).read_bytes()).hexdigest() == chksum -def test_DerivativesDataSink_dtseries_json(tmp_path): +@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative]) +def test_DerivativesDataSink_dtseries_json(tmp_path, interface): cifti_fname = str(tmp_path / "test.dtseries.nii") axes = (nb.cifti2.SeriesAxis(start=0, step=2, size=20), @@ -371,20 +397,22 @@ def test_DerivativesDataSink_dtseries_json(tmp_path): source_file.parent.mkdir(parents=True) source_file.touch() - dds = bintfs.DerivativesDataSink( - in_file=cifti_fname, + prep, save = make_prep_and_save( + interface, base_directory=str(tmp_path), + out_path_base="", + in_file=cifti_fname, source_file=str(source_file), compress=False, - out_path_base="", space="fsLR", grayordinates="91k", RepetitionTime=2.0, ) - res = dds.run() + prep_result = prep.run() + save_result = connect_and_run_save(prep_result, save) - out_path = Path(res.outputs.out_file) + out_path = Path(save_result.outputs.out_file) assert out_path.name == "sub-01_task-rest_space-fsLR_bold.dtseries.nii" old_sidecar = out_path.with_name("sub-01_task-rest_space-fsLR_bold.dtseries.json") @@ -395,6 +423,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path): assert "RepetitionTime" in json.loads(new_sidecar.read_text()) +@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative]) @pytest.mark.parametrize( "space, size, units, xcodes, zipped, fixed, data_dtype", [ @@ -420,7 +449,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path): ], ) def test_DerivativesDataSink_bold( - tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype + tmp_path, interface, space, size, units, xcodes, zipped, fixed, data_dtype ): fname = str(tmp_path / "source.nii") + (".gz" if zipped else "") @@ -431,7 +460,8 @@ def test_DerivativesDataSink_bold( nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname) # BOLD derivative in T1w space - dds = bintfs.DerivativesDataSink( + prep, _ = make_prep_and_save( + interface, base_directory=str(tmp_path), keep_dtype=True, data_dtype=data_dtype or Undefined, @@ -439,10 +469,12 @@ def test_DerivativesDataSink_bold( source_file=BOLD_PATH, space=space or Undefined, in_file=fname, - ).run() + ) - nii = nb.load(dds.outputs.out_file) - assert dds.outputs.fixed_hdr == fixed + prep_result = prep.run() + + nii = nb.load(prep_result.outputs.out_file) + assert prep_result.outputs.fixed_hdr == fixed if data_dtype: assert nii.get_data_dtype() == np.dtype(data_dtype) assert int(nii.header["qform_code"]) == XFORM_CODES[space] @@ -450,6 +482,7 @@ def test_DerivativesDataSink_bold( assert nii.header.get_xyzt_units() == ("mm", "sec") +@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative]) @pytest.mark.parametrize( "space, size, units, xcodes, fixed", [ @@ -473,7 +506,7 @@ def test_DerivativesDataSink_bold( (None, (30, 30, 30), (None, "sec"), (0, 0), [True]), ], ) -def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed): +def test_DerivativesDataSink_t1w(tmp_path, interface, space, size, units, xcodes, fixed): fname = str(tmp_path / "source.nii.gz") hdr = nb.Nifti1Header() @@ -483,22 +516,26 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed): nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname) # BOLD derivative in T1w space - dds = bintfs.DerivativesDataSink( + prep, _ = make_prep_and_save( + interface, base_directory=str(tmp_path), keep_dtype=True, desc="preproc", source_file=T1W_PATH, space=space or Undefined, in_file=fname, - ).run() + ) - nii = nb.load(dds.outputs.out_file) - assert dds.outputs.fixed_hdr == fixed + prep_result = prep.run() + + nii = nb.load(prep_result.outputs.out_file) + assert prep_result.outputs.fixed_hdr == fixed assert int(nii.header["qform_code"]) == XFORM_CODES[space] assert int(nii.header["sform_code"]) == XFORM_CODES[space] assert nii.header.get_xyzt_units() == ("mm", "unknown") +@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative]) @pytest.mark.parametrize( "source_file", [ @@ -510,7 +547,7 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed): @pytest.mark.parametrize("source_dtype", ["