From d85f8539580196c8b8467a401f5a9f2701352a99 Mon Sep 17 00:00:00 2001
From: ttlusty
Date: Thu, 28 Apr 2022 15:26:14 +0300
Subject: [PATCH 1/7] multimodality example

---
 .../classification/multimodality/dataset.py   | 169 +++++++
 .../multimodality/input_processor.py          |  25 +
 .../loss_multimodal_contrastive_learning.py   |  63 +++
 .../multimodality/mg_dataset.py               | 432 ++++++++++++++++++
 .../multimodality/mg_dataset_radiologist.py   | 308 +++++++++++++
 .../multimodality/model_tabular_imaging.py    | 295 ++++++++++++
 .../multimodality/multimodal_paths.py         |  48 ++
 .../multimodality/multimodel_parameters.py    | 216 +++++++++
 .../classification/multimodality/runner.py    | 391 ++++++++++++++++
 9 files changed, 1947 insertions(+)
 create mode 100644 fuse_examples/classification/multimodality/dataset.py
 create mode 100644 fuse_examples/classification/multimodality/input_processor.py
 create mode 100644 fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
 create mode 100644 fuse_examples/classification/multimodality/mg_dataset.py
 create mode 100644 fuse_examples/classification/multimodality/mg_dataset_radiologist.py
 create mode 100644 fuse_examples/classification/multimodality/model_tabular_imaging.py
 create mode 100644 fuse_examples/classification/multimodality/multimodal_paths.py
 create mode 100644 fuse_examples/classification/multimodality/multimodel_parameters.py
 create mode 100644 fuse_examples/classification/multimodality/runner.py

diff --git a/fuse_examples/classification/multimodality/dataset.py b/fuse_examples/classification/multimodality/dataset.py
new file mode 100644
index 000000000..4cf96ce47
--- /dev/null
+++ b/fuse_examples/classification/multimodality/dataset.py
@@ -0,0 +1,169 @@
+import sys
+from typing import Callable, Optional
+import logging
+import pandas as pd
+import pydicom
+import os, glob
+from pathlib import Path
+from typing import Tuple
+
+from fuse.data.visualizer.visualizer_default import FuseVisualizerDefault
+from fuse.data.augmentor.augmentor_default import FuseAugmentorDefault
+from fuse.data.augmentor.augmentor_toolbox import aug_op_color, aug_op_gaussian, aug_op_affine
+from fuse.data.dataset.dataset_default import FuseDatasetDefault
+from fuse.data.dataset.dataset_generator import FuseDatasetGenerator
+from fuse.data.data_source.data_source_default import FuseDataSourceDefault
+
+from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform
+from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt
+from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool
+
+
+from fuse_examples.classification.multimodality.input_processor import ImagingTabularProcessor
+
+
+
+
+
+def IMAGING_dataset():
+    """
+    Creates the imaging augmentation pipeline for the training split.
+    Note: despite its name, this helper does not build a dataset; it returns the
+    augmentor that IMAGING_TABULAR_dataset() plugs into the training FuseDatasetGenerator.
+    Each augmentation is applied with probability 0.5:
+    affine (rotate/translate/flip/scale), color (add/mul/gamma/contrast)
+    and gaussian noise (std=0.03).
+    :return: FuseAugmentorDefault object
+    """
+    augmentation_pipeline = [
+        [
+            ('data.image',),
+            aug_op_affine,
+            {'rotate': Uniform(-30.0, 30.0), 'translate': (RandInt(-10, 10), RandInt(-10, 10)),
+             'flip': (RandBool(0.3), RandBool(0.3)), 'scale': Uniform(0.9, 1.1)},
+            {'apply': RandBool(0.5)}
+        ],
+        [
+            ('data.image',),
+            aug_op_color,
+            {'add': Uniform(-0.06, 0.06), 'mul': Uniform(0.95,
1.05), 'gamma': Uniform(0.9, 1.1), + 'contrast': Uniform(0.85, 1.15)}, + {'apply': RandBool(0.5)} + ], + [ + ('data.image',), + aug_op_gaussian, + {'std': 0.03}, + {'apply': RandBool(0.5)} + ], + ] + + + + # Create data augmentation (optional) + augmentor = FuseAugmentorDefault( + augmentation_pipeline=augmentation_pipeline) + + + + + return augmentor + + +def TABULAR_dataset(tabular_processor,df,tabular_features,sample_key): + tabular_features.remove(sample_key) + tabular_processor = tabular_processor(data=df, + sample_desc_column=sample_key, + columns_to_extract=tabular_features + [sample_key], + columns_to_tensor=tabular_features) + return tabular_processor + + +def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:str,img_key:str,tabular_features_lst: list,sample_key: str, + cache_dir: str = 'cache', reset_cache: bool = False, + post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]: + + + lgr = logging.getLogger('Fuse') + + if isinstance(df,list): + df_train = df[0] + if len(df)>1: + df_val = df[1] + if len(df)>2: + df_test = df[2] + + #---------------------------------------------- + # -----Datasource + train_data_source = FuseDataSourceDefault(input_source=df_train) + validation_data_source = FuseDataSourceDefault(input_source=df_val) + test_data_source = FuseDataSourceDefault(input_source=df_test) + + # ---------------------------------------------- + # -----Data-processors + img_clinical_processor_train = ImagingTabularProcessor(data=df_train, + label=label_key, + img_key = img_key, + image_processor=imaging_processor(''), + tabular_processor= \ + TABULAR_dataset(tabular_processor,df_train,tabular_features_lst.copy(),sample_key)) + + img_clinical_processor_val = ImagingTabularProcessor(data=df_val, + label=label_key, + img_key=img_key, + image_processor=imaging_processor(''), + tabular_processor=\ + TABULAR_dataset(tabular_processor,df_val,tabular_features_lst.copy(),sample_key)) + + img_clinical_processor_test = ImagingTabularProcessor(data=df_test, + label=label_key, + img_key=img_key, + image_processor=imaging_processor(''), + tabular_processor= \ + TABULAR_dataset(tabular_processor,df_test,tabular_features_lst.copy(),sample_key)) + + + + visualiser = FuseVisualizerDefault(image_name='data.image', label_name='data.gt') + + + # ---------------------------------------------- + # ------ Dataset + train_dataset = FuseDatasetGenerator(cache_dest=cache_dir, + data_source=train_data_source, + processor=img_clinical_processor_train, + augmentor=IMAGING_dataset(), + visualizer=visualiser, + post_processing_func=post_cache_processing_func,) + + + validation_dataset = FuseDatasetGenerator(cache_dest=cache_dir, + data_source=validation_data_source, + processor=img_clinical_processor_val, + augmentor=None, + visualizer=visualiser, + post_processing_func=post_cache_processing_func,) + + test_dataset = FuseDatasetGenerator(cache_dest=cache_dir, + data_source=test_data_source, + processor=img_clinical_processor_test, + augmentor=None, + visualizer=visualiser, + post_processing_func=post_cache_processing_func,) + + + # ---------------------------------------------- + # ------ Cache + + # create cache + train_dataset.create(reset_cache=reset_cache) # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading + validation_dataset.create() # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading + test_dataset.create() # use ThreadPool to create this dataset, to avoid cv2 problems 
in multithreading
+
+    lgr.info(f'- Load and cache data:')
+
+    lgr.info(f'- Load and cache data: Done')
+
+    return train_dataset, validation_dataset, test_dataset
+
+
diff --git a/fuse_examples/classification/multimodality/input_processor.py b/fuse_examples/classification/multimodality/input_processor.py
new file mode 100644
index 000000000..549ea2427
--- /dev/null
+++ b/fuse_examples/classification/multimodality/input_processor.py
@@ -0,0 +1,25 @@
+def sample_desc_to_xml_path(df, sample_desc, img_key):
+    xml_path = df[img_key][df.sample_desc == sample_desc].values
+    return xml_path
+def get_gt_from_tabular_sample(tabular_sample_dict, gt_key):
+    gt = tabular_sample_dict[gt_key]
+    tabular_sample_dict.pop(gt_key)
+    return tabular_sample_dict, gt
+
+class ImagingTabularProcessor:
+    def __init__(self, data, label, img_key, image_processor, tabular_processor):
+        self.image_processor = image_processor
+        self.tabular_processor = tabular_processor
+        self.data = data
+        self.label = label
+        self.img_key = img_key
+    def __call__(self, sample_desc):
+        img_path = sample_desc_to_xml_path(self.data, sample_desc, self.img_key)
+        tabular_sample_dict = self.tabular_processor(sample_desc)
+        image_dict_list = self.image_processor(img_path[0][0])
+        tabular_sample_dict, gt = get_gt_from_tabular_sample(tabular_sample_dict.copy(), self.label)
+        img_sample_dict = image_dict_list
+        sample = tabular_sample_dict
+        sample['image'] = img_sample_dict
+        sample['gt'] = gt
+        return sample
\ No newline at end of file
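# --- Usage sketch (illustration only, not part of the patch) ----------------
# ImagingTabularProcessor composes one tabular processor and one image
# processor into a single sample dict keyed by sample_desc: the tabular label
# is popped out as 'gt' and the image payload is attached under 'image'.
# The Dummy* processors and the toy dataframe below are hypothetical stand-ins.
import pandas as pd

class DummyImageProcessor:
    def __call__(self, img_path):
        return {'path': img_path}  # a real processor would return the decoded image tensor

class DummyTabularProcessor:
    def __call__(self, sample_desc):
        return {'age': 42.0, 'finding_biopsy': 1}

df = pd.DataFrame({'sample_desc': ['s1'], 'dcm_url': [['file1.dcm']]})
proc = ImagingTabularProcessor(data=df, label='finding_biopsy', img_key='dcm_url',
                               image_processor=DummyImageProcessor(),
                               tabular_processor=DummyTabularProcessor())
print(proc('s1'))  # -> {'age': 42.0, 'image': {'path': 'file1.dcm'}, 'gt': 1}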
diff --git a/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py b/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
new file mode 100644
index 000000000..a2967dd77
--- /dev/null
+++ b/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
@@ -0,0 +1,63 @@
+from typing import Dict
+import torch
+import torch.nn.functional as F
+from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
+
+
+def softcrossentropyloss(target, logits):
+    """
+    From the pytorch discussion Forum:
+    https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501
+    """
+    logprobs = torch.nn.functional.log_softmax(logits, dim=1)
+    loss = -(target * logprobs).sum() / logits.shape[0]
+    return loss
+
+
+class FuseLossMultimodalContrastiveLearning:
+    def __init__(self,
+                 imaging_representations: str = None,
+                 tabular_representations: str = None,
+                 label: str = None,
+                 temperature: float = 1.0,
+                 alpha: float = 0.5
+                 ) -> None:
+        self.imaging_representations = imaging_representations
+        self.tabular_representations = tabular_representations
+        self.temperature = temperature
+        self.label = label
+        self.alpha = alpha
+
+    def __call__(self, batch_dict: Dict) -> torch.Tensor:
+        # filter batch_dict if required
+        imaging_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.imaging_representations)
+        tabular_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.tabular_representations)
+        label = FuseUtilsHierarchicalDict.get(batch_dict, self.label)
+        if len(imaging_representations.shape) < 2:
+            imaging_representations = imaging_representations.unsqueeze(dim=0)
+        if len(tabular_representations.shape) < 2:
+            tabular_representations = tabular_representations.unsqueeze(dim=0)
+        imaging_representations = F.normalize(imaging_representations, p=2, dim=1)
+        tabular_representations = F.normalize(tabular_representations, p=2, dim=1)
+        label_vec = torch.unsqueeze(label, 0)
+        mask = torch.eq(torch.transpose(label_vec, 0, 1), label_vec).float()
+        logits_imaging_tabular = torch.matmul(imaging_representations, torch.transpose(tabular_representations, 0, 1))/self.temperature
+        logits_tabular_imaging = torch.matmul(tabular_representations, torch.transpose(imaging_representations, 0, 1))/self.temperature
+        loss_imaging_tabular = softcrossentropyloss(mask, logits_imaging_tabular)/torch.sum(mask, 0)
+        loss_tabular_imaging = softcrossentropyloss(mask, logits_tabular_imaging)/torch.sum(mask, 0)
+        return self.alpha*loss_tabular_imaging.sum() + (1-self.alpha)*loss_imaging_tabular.sum()
+
+
+if __name__ == '__main__':
+    import torch
+
+    batch_dict = {'model.imaging_representations': torch.randn(3, 2),
+                  'model.tabular_representations': torch.randn(3, 2),
+                  'data.label': torch.empty(3, dtype=torch.long).random_(2)}
+
+    loss = FuseLossMultimodalContrastiveLearning(temperature=0.1,
+                                                 imaging_representations='model.imaging_representations',
+                                                 tabular_representations='model.tabular_representations',
+                                                 label='data.label')
+    res = loss(batch_dict)
+    print('Loss output = ' + str(res))
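# --- Behavior sketch (illustration only, toy tensors) -----------------------
# FuseLossMultimodalContrastiveLearning computes a soft cross-entropy over the
# temperature-scaled similarity matrix between L2-normalized imaging and
# tabular embeddings; `mask` marks same-label pairs as the positives, i.e. a
# supervised, cross-modal variant of a contrastive loss. Mirroring the
# module's __main__ demo above, with deliberately aligned embeddings:
import torch
import torch.nn.functional as F

emb = F.normalize(torch.randn(4, 8), dim=1)
batch = {'model.imaging_representations': emb,
         'model.tabular_representations': emb.clone(),  # identical across modalities
         'data.label': torch.tensor([0, 0, 1, 1])}
loss_fn = FuseLossMultimodalContrastiveLearning(
    imaging_representations='model.imaging_representations',
    tabular_representations='model.tabular_representations',
    label='data.label', temperature=0.1)
print(loss_fn(batch))  # scalar; lower when same-label pairs are more similar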
diff --git a/fuse_examples/classification/multimodality/mg_dataset.py b/fuse_examples/classification/multimodality/mg_dataset.py
new file mode 100644
index 000000000..adcb7a1ff
--- /dev/null
+++ b/fuse_examples/classification/multimodality/mg_dataset.py
@@ -0,0 +1,444 @@
+import pandas as pd
+import os
+from typing import Callable, Optional
+from typing import Tuple
+
+# The imports below are required by get_selected_features_clinical(); the names
+# `config`, `sc`, `ClinicalData` and `HELD_OUT_KEY` it uses must come from the host project.
+import logging
+import sys
+import numpy as np
+from os.path import join
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.feature_selection import SelectFromModel, f_classif, chi2
+from xgboost import XGBClassifier
+
+from MedicalAnalyticsCore.DatabaseUtils.selected_studies_queries import get_annotations_and_findings
+from MedicalAnalyticsCore.DatabaseUtils.tableResolver import TableResolver
+from MedicalAnalyticsCore.DatabaseUtils.connection import create_homer_engine, Connection
+from MedicalAnalyticsCore.DatabaseUtils import tableNames
+from MedicalAnalyticsCore.DatabaseUtils import db_utils as db
+
+
+# from autogluon.tabular import TabularPredictor
+from fuse_examples.classification.multimodality.dataset import IMAGING_TABULAR_dataset
+from fuse.data.dataset.dataset_default import FuseDatasetDefault
+
+from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor
+from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame
+
+
+from typing import Dict, List
+import torch
+from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
+
+class PostProcessing:
+    def __init__(self, continuous_tabular_features_lst: List, categorical_tabular_features_lst: List, label_lst: List,
+                 imaging_features_lst: List, non_imaging_features_lst: List, use_imaging: bool, use_non_imaging: bool):
+        self.continuous_tabular_features_lst = continuous_tabular_features_lst
+        self.categorical_tabular_features_lst = categorical_tabular_features_lst
+        self.label_lst = label_lst
+        self.imaging_features_lst = imaging_features_lst
+        self.non_imaging_features_lst = non_imaging_features_lst
+        self.use_imaging = use_imaging
+        self.use_non_imaging = use_non_imaging
+
+    def __call__(self, batch_dict: Dict) -> Dict:
+        if not self.use_imaging and not self.use_non_imaging:
+            raise ValueError('No features are in use')
+        mask_list = self.use_imaging * self.imaging_features_lst + self.use_non_imaging * self.non_imaging_features_lst
+        mask_continuous = torch.zeros(len(self.continuous_tabular_features_lst))
+        for i in range(len(mask_list)):
+            try:
+                mask_continuous[self.continuous_tabular_features_lst.index(mask_list[i])] = 1
+            except ValueError:  # feature is not in the continuous list
+                pass
+        mask_categorical = torch.zeros(len(self.categorical_tabular_features_lst))
+        for i in range(len(mask_list)):
+            try:
+                mask_categorical[self.categorical_tabular_features_lst.index(mask_list[i])] = 1
+            except ValueError:  # feature is not in the categorical list
+                pass
+        categorical = [FuseUtilsHierarchicalDict.get(batch_dict, 'data.' + feature_name) for feature_name in self.categorical_tabular_features_lst]
+        for i in range(len(categorical)):
+            if categorical[i].dim() == 0:
+                categorical[i] = torch.unsqueeze(categorical[i], 0)
+        categorical_tensor = torch.cat(tuple(categorical), 0)
+        categorical_tensor = categorical_tensor.float()
+        categorical_tensor = torch.mul(categorical_tensor, mask_categorical)
+        FuseUtilsHierarchicalDict.set(batch_dict, 'data.categorical', categorical_tensor.float())
+        continuous = [FuseUtilsHierarchicalDict.get(batch_dict, 'data.' + feature_name) for feature_name in self.continuous_tabular_features_lst]
+        for i in range(len(continuous)):
+            if continuous[i].dim() == 0:
+                continuous[i] = torch.unsqueeze(continuous[i], 0)
+        continuous_tensor = torch.cat(tuple(continuous), 0)
+        continuous_tensor = continuous_tensor.float()
+        continuous_tensor = torch.mul(continuous_tensor, mask_continuous)
+        FuseUtilsHierarchicalDict.set(batch_dict, 'data.continuous', continuous_tensor.float())
+        label = FuseUtilsHierarchicalDict.get(batch_dict, 'data.' + self.label_lst[0])
+        FuseUtilsHierarchicalDict.set(batch_dict, 'data.label', label.long())
+        feature_lst = self.continuous_tabular_features_lst + self.categorical_tabular_features_lst + self.label_lst
+        for feature in feature_lst:
+            FuseUtilsHierarchicalDict.pop(batch_dict, 'data.' + feature)
+        return batch_dict
+
+
+# feature selection: univariate analysis
+
+def get_selected_features_clinical(output_path):
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    # initialize logger
+    logger = logging.getLogger("BigMed")
+    logger.setLevel(logging.INFO)
+    logger.addHandler(logging.StreamHandler(sys.stdout))
+    handler = logging.FileHandler(join(output_path, "feature_selection_messages.log"))
+    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    logger.info("\n============ Configuration ============\n")
+    logger.info("{")
+    for key, value in config.items():
+        logger.info("%s: %s" % (key, value))
+    logger.info("}\n")
+
+    if 'column_names' in config:
+        all_df = ClinicalData.read_file(config['filename'], columns_names=config['column_names']).clinical_df
+    else:
+        all_df = pd.read_csv(config['filename'])
+    if 'selected_columns' in config:
+        all_df = all_df[[sc.SUBJECT_ID] + config['selected_columns'] + [config['label']]]
+    all_df = all_df.loc[~np.isnan(all_df[config['label']].values), :]  # remove rows with missing label
+    scaler = MinMaxScaler(feature_range=(0, 1))
+
+    if 'folds_file' in config:
+        folds_df = pd.read_csv(config['folds_file'], sep=',', header=0, index_col=0)
+        unq = folds_df['fold'].unique()
+        n_folds = len(unq)-1 if HELD_OUT_KEY in unq else len(unq)
+        all_train_rows = folds_df[~folds_df['fold'].isin([HELD_OUT_KEY, -1])]
+        all_train_subjects = all_train_rows['patient'].tolist()
+        mask = all_df[sc.SUBJECT_ID].isin(all_train_subjects)
+        all_train_df = all_df.loc[mask]
+    else:
+        n_folds = 1
+
+    selected_features = {}
+    for fold in range(n_folds):
+        if n_folds > 1:
+            train_rows = folds_df[~folds_df['fold'].isin([fold, HELD_OUT_KEY, -1])]
+            train_subjects = train_rows['patient'].tolist()
+        else:
+            train_subjects = all_df[sc.SUBJECT_ID].tolist()
+        mask = all_df[sc.SUBJECT_ID].isin(train_subjects)
+        fold_train_df = all_df.loc[mask]
+        fold_train_df = fold_train_df.drop([sc.SUBJECT_ID],
axis=1) + + fold_train_df = fold_train_df.fillna(fold_train_df.mean()) + y = fold_train_df[config['label']].values + fold_train_df = fold_train_df.drop([config['label']], axis=1) + + logger.info('============ Features selection by SelectFromModel for fold {} ============\n'.format(fold)) + for i in range(len(config['classifiers'])): + cur_df = fold_train_df + if config['apply_scaler'][i]: + cur_df[cur_df.columns] = scaler.fit_transform(cur_df[cur_df.columns]) + X = cur_df.values + cls_name = config['classifiers'][i] + if cls_name == 'RandomForestClassifier': + cls = RandomForestClassifier(**config['classifiers_params'][i]) + elif cls_name == 'LogisticRegression': + cls = LogisticRegression(**config['classifiers_params'][i]) + elif cls_name == 'XGBClassifier': + cls = XGBClassifier(**config['classifiers_params'][i]) + + sfm = SelectFromModel(cls).fit(X, y).get_support() + cur_features = cur_df.columns.values[sfm].tolist() + if cls_name not in selected_features: + selected_features[cls_name] = [cur_features] + else: + selected_features[cls_name] += [cur_features] + logger.info('Out of {} features, found for fold {} the following {} features by SelectFromModel with {}:\n{}\n'.format(len(X[0]), fold, len(cur_features), cls_name, cur_features)) + + logger.info('\n============ Features selection by UnivariateTest for fold {} ============\n'.format(fold)) + _, p = f_classif(fold_train_df.values, y) # analyze variance for all features + cat_features_filter = (fold_train_df.nunique() <= 5) # find categorial features that have at most 5 values + if cat_features_filter.any(): + _, p[cat_features_filter] = chi2(fold_train_df.values[:, cat_features_filter], y) # analyze dependence for categorial features + + sig_columns = [] + for i in np.argsort(p): + name = fold_train_df.columns[i] + if p[i] < config['p_threshold']: + sig_columns += [name] + if 'UnivariateTest' not in selected_features: + selected_features['UnivariateTest'] = [sig_columns] + else: + selected_features['UnivariateTest'] += [sig_columns] + logger.info('Out of {} features, found for fold {} the following {} significant columns in descending order:\n{}\n\n'.format(len(X[0]), fold, len(sig_columns), sig_columns)) + + logger.info('\n============ Features selection Intersection ============\n') + for key in selected_features: + intersect_features = set(selected_features[key][0]).intersection(*selected_features[key]) + logger.info('Out of {} features, found the following {} intersect features for {}:\n{}\n\n'.format(len(X[0]), len(intersect_features), key, list(intersect_features))) + + logger.info('\n============ Features selection by UnivariateTest on all Train data ============\n') + all_train_df = all_train_df.fillna(all_train_df.mean()) + y = all_train_df[config['label']].values + all_train_df = all_train_df.drop([sc.SUBJECT_ID, config['label']], axis=1) + _, p = f_classif(all_train_df.values, y) # analyze variance for all features + cat_features_filter = (all_train_df.nunique() <= 5) # find categorial features that have at most 5 values + if cat_features_filter.any(): + _, p[cat_features_filter] = chi2(all_train_df.values[:, cat_features_filter], y) # analyze dependence for categorial features + sig_columns = [] + for i in np.argsort(p): + name = all_train_df.columns[i] + if p[i] < config['p_threshold']: + sig_columns += [name] + logger.info('Out of {} features, found the following {} significant columns in descending order:\n{}\n\n'.format(len(X[0]), len(sig_columns), sig_columns)) + + handlers = logger.handlers[:] + for handler in 
handlers: + handler.close() + logger.removeHandler(handler) +#-------------------Tabular +def get_selected_features_mg(data,features_mode,key_columns): + features_dict = tabular_feature_mg(data) + columns_names = list(data.columns) + if features_mode == 'full': + selected_col = \ + features_dict['icd9_feat'] + \ + features_dict['labs_feat'] + \ + features_dict['hrt_feat'] + \ + features_dict['fam_feats'] + \ + features_dict['biopsy_feat'] + \ + features_dict['smoking_feat'] + \ + features_dict['demo_feat'] + \ + features_dict['sympt_feat'] + \ + features_dict['meds_feat'] + \ + features_dict['prev_finding_feat'] + \ + features_dict['gynec_feat'] + \ + features_dict['genetic_feat'] + \ + features_dict['dicom'] + elif features_mode == 'icd9_feat': + selected_col = \ + features_dict['icd9_feat'] + elif features_mode == 'labs_feat': + selected_col = \ + features_dict['labs_feat'] + selected_col = selected_col + key_columns + selected_colIx = [columns_names.index(selected_col[i]) for i in range(len(selected_col))] + return selected_col,selected_colIx + + +def tabular_feature_mg(data): + features_dict = {} + features_dict['icd9_feat'] = [x for x in data if x.startswith('dx_')] + features_dict['labs_feat'] = [x for x in data if x.startswith('labs')] + features_dict['hrt_feat'] = [x for x in data if x.startswith('HRT')] + features_dict['outcome_feat'] = [x for x in data if x.startswith('outcome')] + features_dict['fam_feats'] = [x for x in data if x.startswith('family')] + features_dict['biopsy_feat'] = ['prev_biopsy_result_max', 'past_biopsy_proc_ind', 'past_biopsy_proc_cnt'] + features_dict['smoking_feat'] = [x for x in data if x.startswith('smoking')] + features_dict['radio_feat'] = [x for x in data if 'birads' in x] + [x for x in data if 'breast_density' in x] +\ + [x for x in data if 'breast_MRI' in x] + features_dict['demo_feat'] = [ 'age', 'race', 'religion', 'bmi_max', 'bmi_last', 'weight_max', 'weight_last', + 'osteoporosis_ind', 'bmi_current', 'diabetes_ind']#, 'calc_bmi_current', 'calc_likelihood_obesity'] + features_dict['sympt_feat'] = [ 'pain_cnt', 'nipple_retraction_ind_past', + 'lump_by_dr_ind_past', 'nipple_retraction_ind_current', + 'infection_current_ind', 'lump_by_dr_cnt', 'nipple_retraction_cnt', + 'lump_by_dr_ind_current', 'breast_disorder_ind', + 'breast_disorder_current_ind', 'nipple_allocation_ind_current', 'nipple_allocation_cnt', + 'nipple_allocation_ind_past', 'complaint_ind_current', 'complaint_ind_past', + 'pain_ind_past', 'pain_ind_current', 'infection_current_ind_last'] + features_dict['meds_feat'] = [ 'oral_contraceptives_ind_current', 'progesterons_ind', 'oral_contraceptives_ind_past'] + features_dict['gynec_feat'] = ['has_breastfed_ind', 'children_ind', 'children_cnt', 'age_last_menstruation', 'menopause_ind', + 'age_first_childbirth', 'pregnancies_cnt', 'pregnancies_ind', 'age_first_menstruation', + 'menstruation_years', 'menopause_dx_ind', 'menarche_to_ftp_years'] + features_dict['prev_finding_feat'] = ['prev_high_risk_ind', 'cancer_hist_any_ind', 'prev_benign_cnt', 'prev_benign_ind', + 'prev_high_risk_cnt'] + features_dict['genetic_feat'] = ['genetic_consult_ind'] + features_dict['images'] = ['LCC_micro','RCC_micro','LMLO_micro','RMLO_micro', + 'LCC_pred_classA','LCC_pred_classB', 'LCC_pred_classC', 'LCC_pred_classD', 'LCC_pred_classE', + 'RCC_pred_classA', 'RCC_pred_classB','RCC_pred_classC', 'RCC_pred_classD','RCC_pred_classE', + 'LMLO_pred_classA', 'LMLO_pred_classB', 'LMLO_pred_classC', 'LMLO_pred_classD', 'LMLO_pred_classE', + 'RMLO_pred_classA', 
'RMLO_pred_classB', 'RMLO_pred_classC', 'RMLO_pred_classD', 'RMLO_pred_classE', + 'LCC_findings_size', 'RCC_findings_size', 'LMLO_findings_size', 'RMLO_findings_size', + 'LCC_findings_x_max', 'RCC_findings_x_max', 'LMLO_findings_x_max', 'RMLO_findings_x_max', + 'LCC_findings_y_max', 'RCC_findings_y_max', 'LMLO_findings_y_max', 'RMLO_findings_y_max', + 'Calcification', 'Breast Assymetry', 'Tumor', 'Architectural Distortion', 'Axillary lymphadenopathy', + 'spiculated_lesions_report', 'architectural_distortion_report', 'suspicious_calcifications_report'] + features_dict['dicom'] = ['DistanceSourceToPatient_AVG_CC', 'DistanceSourceToDetector_AVG_CC', + 'XRayTubeCurrent_AVG_CC', 'CompressionForce_AVG_CC', + 'ExposureTime_AVG_CC', 'KVP_AVG_CC', 'BodyPartThickness_AVG_CC', + 'RelativeXRayExposure_AVG_CC', 'ExposureInuAs_AVG_CC', + 'DistanceSourceToPatient_AVG_MLO', 'DistanceSourceToDetector_AVG_MLO', + 'XRayTubeCurrent_AVG_MLO', 'CompressionForce_AVG_MLO', + 'ExposureTime_AVG_MLO', 'KVP_AVG_MLO', 'BodyPartThickness_AVG_MLO', + 'RelativeXRayExposure_AVG_MLO', 'ExposureInuAs_AVG_MLO'] + + return features_dict + + +def tabular_mg(tabular_filename,key_columns): + data = pd.read_csv(tabular_filename) + column_names,column_colIx = get_selected_features_mg(data, 'full',key_columns) + df_tabular = data[column_names] + return df_tabular + + +#------------------Imaging +def imaging_mg(imaging_filename,key_columns): + label_column = 'finding_biopsy' + img_sample_column = 'dcm_url' + + if os.path.exists(imaging_filename): + df = pd.read_csv(imaging_filename) + else: + REVISION_DATE = '20200915' + TableResolver().set_revision(REVISION_DATE) + revision = {'prefix': 'sentara', 'suffix': REVISION_DATE} + engine = Connection().get_engine() + + df_with_findings = get_annotations_and_findings(engine, revision, + exam_types=['MG'], viewpoints=None, # ['CC','MLO'], \ + include_findings=True, remove_invalids=True, + remove_heldout=False, \ + remove_excluded=False, remove_less_than_4views=False, \ + load_from_file=False, save_to_file=False) + + # dicom_table = db.get_table_as_dataframe(engine, tableNames.get_dicom_tags_table_name(revision)) + # study_statuses = db.get_table_as_dataframe(engine, tableNames.get_study_statuses_table_name(revision)) + my_providers = ['sentara'] + df = df_with_findings.loc[df_with_findings['provider'].isin(my_providers)] + # fixing assymetry + asymmetries = ['asymmetry', 'developing asymmetry', 'focal asymmetry', 'global asymmetry'] + df['is_asymmetry'] = df['pathology'].isin(asymmetries) + df['is_Breast_Assymetry'] = df['type'].isin(['Breast Assymetry']) + df.loc[df['is_asymmetry'], 'pathology'] = df[df['is_asymmetry']]['biopsy_outcome'] + df.loc[df['is_Breast_Assymetry'], 'pathology'] = df[df['is_Breast_Assymetry']]['biopsy_outcome'] + # remove duble xmls + aa_unsorted = df + aa_unsorted.sort_values('xml_url', ascending=False, inplace=True) + xml_url_to_keep = aa_unsorted.groupby(['image_id'])['xml_url'].transform('first') + df = aa_unsorted[aa_unsorted['xml_url'] == xml_url_to_keep] + remove_from_pathology = ['undefined', 'not_applicable', 'Undefined', 'extracapsular rupture of breast implant', + 'intracapsular rupture of breast implant'] + is_pathology = ~df.pathology.isnull() & ~df.pathology.isin(remove_from_pathology) + is_digital = df.image_source == 'Digital' + is_biopsy = df.finding_biopsy.isin(['negative', 'negative high risk', 'positive']) + df = df[(is_digital) & (is_pathology) & (is_biopsy)] + df.to_csv(imaging_filename) + + df1 = 
df.groupby(key_columns)[img_sample_column].apply(lambda x: list(map(str, x))).reset_index() + df2 = df.groupby(key_columns)[label_column].apply(lambda x: list(map(str, x))).reset_index() + + + return pd.merge(df1,df2,on=key_columns) + + +#------------------Imaging+Tabular +def merge_datasets(tabular_filename,imaging_filename,key_columns): + tabular_data = tabular_mg(tabular_filename, key_columns) + imaging_data = imaging_mg(imaging_filename, key_columns) + tabular_columns = tabular_data.columns.values + imaging_columns = imaging_data.columns.values + dataset = pd.merge(tabular_data, imaging_data, on=key_columns, how='inner') + return dataset,tabular_columns,imaging_columns + +#------------------Baseline +def apply_gluon_baseline(train_set,test_set,label,save_path): + + predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set) + results = predictor.fit_summary(show_plot=True) + + # Inference time: + y_test = test_set[label] + test_data = test_set.drop(labels=[label], + axis=1) # delete labels from test data since we wouldn't have them in practice + print(test_data.head()) + + predictor = TabularPredictor.load( + save_path) + y_pred = predictor.predict_proba(test_data) + perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) + +def MG_dataset(tabular_filename:str, + imaging_filename:str, + train_val_test_filenames:list, + + imaging_processor, + tabular_processor, + + key_columns:list, + label_key:str, + img_key:str, + sample_key: str, + + cache_dir: str = 'cache', + reset_cache: bool = False, + post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]: + + + dataset, tabular_columns, imaging_columns = merge_datasets(tabular_filename, imaging_filename, key_columns) + + dataset['finding_biopsy'] = [1 if 'positive' in sample else 0 for sample in list(dataset[label_key])] + dataset = dataset.loc[:, ~dataset.columns.duplicated()] + dataset.rename(columns={'study_id': sample_key}, inplace=True) + + train_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[0])['study_id'])] + val_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[1])['study_id'])] + test_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[2])['study_id'])] + + features_list = list(tabular_columns) + [features_list.remove(x) for x in key_columns] + train_dataset, validation_dataset, test_dataset = IMAGING_TABULAR_dataset( + df=[train_set, val_set, test_set], + imaging_processor=imaging_processor, + tabular_processor=tabular_processor, + label_key=label_key, + img_key=img_key, + tabular_features_lst=features_list + [label_key] + [sample_key], + sample_key=sample_key, + cache_dir=cache_dir, + reset_cache=reset_cache, + post_cache_processing_func=post_cache_processing_func + ) + + return train_dataset, validation_dataset, test_dataset + + + +if __name__ == "__main__": + data_path = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/mg_clinical_dicom_sentra/' + tabular_filename = os.path.join(data_path, 'fx_sentara_cohort_processed.csv') + imaging_filename = os.path.join(data_path, 'mg_sentara_cohort.csv') + + train_val_test_filenames = [os.path.join(data_path, 'sentara_train_pathologies.csv'), + os.path.join(data_path, 'sentara_val_pathologies.csv'), + os.path.join(data_path, 'sentara_test_pathologies.csv'), ] + + key_columns = ['patient_id', 'study_id'] + fuse_key_column = 'sample_desc' + label_column = 'finding_biopsy' + 
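# --- Data-shaping sketch (illustration only, toy values) --------------------
# imaging_mg() above collapses per-image rows into one row per study: each key
# group becomes a list of dcm_url strings and a list of finding_biopsy strings,
# which MG_dataset() then binarizes ('positive' anywhere in the list -> 1).
import pandas as pd
toy = pd.DataFrame({'study_id': ['s1', 's1', 's2'],
                    'dcm_url': ['a.dcm', 'b.dcm', 'c.dcm'],
                    'finding_biopsy': ['negative', 'positive', 'negative']})
grouped = toy.groupby(['study_id'])['finding_biopsy'].apply(lambda x: list(map(str, x))).reset_index()
labels = [1 if 'positive' in sample else 0 for sample in list(grouped['finding_biopsy'])]
# grouped['finding_biopsy'] -> [['negative', 'positive'], ['negative']]; labels -> [1, 0]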
img_sample_column = 'dcm_url' + train_dataset, validation_dataset, test_dataset = \ + MG_dataset(tabular_filename=tabular_filename, + imaging_filename=imaging_filename, + train_val_test_filenames=train_val_test_filenames, + key_columns=key_columns, + sample_key=fuse_key_column, + label_key=label_column, + img_key=img_sample_column, + cache_dir='./lala/', + reset_cache=False, + imaging_processor=FuseMGInputProcessor, + tabular_processor=FuseProcessorDataFrame, + ) + + + # apply_gluon_baseline(train_set[tabular_columns+[label_column]], + # test_set[tabular_columns+[label_column]],label_column,'./Results/MG+clinical/gluon_baseline/') \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/mg_dataset_radiologist.py b/fuse_examples/classification/multimodality/mg_dataset_radiologist.py new file mode 100644 index 000000000..04eb1a8fa --- /dev/null +++ b/fuse_examples/classification/multimodality/mg_dataset_radiologist.py @@ -0,0 +1,308 @@ +import pandas as pd +import os +from typing import Callable, Optional +from typing import Tuple + +# from MedicalAnalyticsCore.DatabaseUtils.selected_studies_queries import get_annotations_and_findings +# from MedicalAnalyticsCore.DatabaseUtils.tableResolver import TableResolver +# from MedicalAnalyticsCore.DatabaseUtils.connection import create_homer_engine, Connection +# from MedicalAnalyticsCore.DatabaseUtils import tableNames +# from MedicalAnalyticsCore.DatabaseUtils import db_utils as db + + +# from autogluon.tabular import TabularPredictor +from fuse_examples.classification.multimodality.dataset import IMAGING_TABULAR_dataset +from fuse.data.dataset.dataset_default import FuseDatasetDefault + +from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor +from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame + + +from typing import Dict, List +import torch +from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict + +class PostProcessing: + def __init__(self, continuous_tabular_features_lst: List, + categorical_tabular_features_lst: List, + label_lst: List, + imaging_features_lst: List, + non_imaging_features_lst: List, + use_imaging: bool, + use_non_imaging: bool): + self.continuous_tabular_features_lst = continuous_tabular_features_lst + self.categorical_tabular_features_lst = categorical_tabular_features_lst + self.label_lst = label_lst + self.imaging_features_lst = imaging_features_lst + self.non_imaging_features_lst = non_imaging_features_lst + self.use_imaging = use_imaging + self.use_non_imaging = use_non_imaging + + def __call__(self, batch_dict: Dict) -> Dict: + if not self.use_imaging and not self.use_non_imaging: + raise ValueError('No features are in use') + mask_list = self.use_imaging * self.imaging_features_lst + self.use_non_imaging * self.non_imaging_features_lst + mask_continuous = torch.zeros(len( self.continuous_tabular_features_lst)) + for i in range(len(mask_list)): + try: + mask_continuous[self.continuous_tabular_features_lst.index(mask_list[i])] = 1 + except: + pass + mask_categorical = torch.zeros(len( self.categorical_tabular_features_lst)) + for i in range(len(mask_list)): + try: + mask_categorical[self.categorical_tabular_features_lst.index(mask_list[i])] = 1 + except: + pass + categorical = [FuseUtilsHierarchicalDict.get(batch_dict, 'data.' 
+ feature_name) for feature_name in self.categorical_tabular_features_lst] + for i in range(len(categorical)): + if categorical[i].dim() == 0: + categorical[i] = torch.unsqueeze(categorical[i], 0) + categorical_tensor = torch.cat(tuple(categorical), 0) + categorical_tensor = categorical_tensor.float() + categorical_tensor = torch.mul(categorical_tensor, mask_categorical) + FuseUtilsHierarchicalDict.set(batch_dict, 'data.categorical', categorical_tensor.float()) + continuous = [FuseUtilsHierarchicalDict.get(batch_dict, 'data.' + feature_name) for feature_name in self.continuous_tabular_features_lst] + for i in range(len(continuous)): + if continuous[i].dim() == 0: + continuous[i] = torch.unsqueeze(continuous[i], 0) + continuous_tensor = torch.cat(tuple(continuous), 0) + continuous_tensor = continuous_tensor.float() + continuous_tensor = torch.mul(continuous_tensor, mask_continuous) + FuseUtilsHierarchicalDict.set(batch_dict, 'data.continuous', continuous_tensor.float()) + label = FuseUtilsHierarchicalDict.get(batch_dict, 'data.' + self.label_lst[0]) + FuseUtilsHierarchicalDict.set(batch_dict, 'data.' + self.label_lst[0], label.long()) + feature_lst = self.continuous_tabular_features_lst + self.categorical_tabular_features_lst + for feature in feature_lst: + FuseUtilsHierarchicalDict.pop(batch_dict, 'data.' + feature) + return batch_dict + + +# feature selection univarient analysis + +#-------------------Tabular +def get_selected_features_mg(data,features_mode,key_columns): + features_dict = tabular_feature_mg() + columns_names = list(data.columns) + if features_mode == 'full': + selected_col = \ + features_dict['continuous_clinical_feat'] + \ + features_dict['categorical_clinical_feat'] + + selected_col = selected_col + key_columns + selected_colIx = [columns_names.index(selected_col[i]) for i in range(len(selected_col))] + return selected_col,selected_colIx + + +def tabular_feature_mg(): + features_dict = {} + + features_dict['continuous_clinical_feat'] = ['findings_size', 'findings_x_max', 'findings_y_max', 'DistanceSourceToPatient', + 'DistanceSourceToDetector', 'x_pixel_spacing', 'XRayTubeCurrent', 'CompressionForce', + 'exposure_time', 'KVP', 'body_part_thickness', 'RelativeXRayExposure', 'exposure_in_mas', + 'age'] #14 continuous clinical features + features_dict['categorical_clinical_feat'] = ['side', 'is_distortions', 'is_spiculations', 'is_susp_calcifications', + 'breast_density_1', 'breast_density_2', 'breast_density_3', + 'breast_density_4', 'final_side_birad_0', 'final_side_birad_1', + 'final_side_birad_2', 'final_side_birad_3', 'final_side_birad_4', + 'final_side_birad_5', 'final_side_birad_6', 'final_side_birad_7', + 'final_side_birad_8', 'birad_0', 'birad_1', 'birad_2', 'birad_3', 'birad_4', + 'birad_5', 'birad_6', 'birad_7', 'birad_8', 'calcification_0', + 'calcification_1', 'calcification_2', 'calcification_3', 'calcification_4', + 'calcification_5', 'calcification_6', 'calcification_7', 'calcification_8', + 'calcification_9', 'longitudinal_change_0', 'longitudinal_change_1', + 'longitudinal_change_2', 'longitudinal_change_3', 'longitudinal_change_4', + 'type_0', 'type_1', 'type_2', 'type_3', 'type_4', 'type_5', 'type_6', 'race_0', + 'race_1', 'race_2', 'race_3', 'race_4', 'race_5', 'race_6', 'race_7', 'race_8', + 'race_9', 'race_10', 'max_prev_birad_class_0', 'max_prev_birad_class_1', + 'max_prev_birad_class_2', 'max_prev_birad_class_3'] #63 categorical clinical features + features_dict['visual_feat'] = ['findings_size', 'findings_x_max', 'findings_y_max', 'side', 
'is_distortions', + 'is_spiculations', 'is_susp_calcifications', + 'breast_density_1', 'breast_density_2', 'breast_density_3', + 'breast_density_4', 'final_side_birad_0', 'final_side_birad_1', + 'final_side_birad_2', 'final_side_birad_3', 'final_side_birad_4', + 'final_side_birad_5', 'final_side_birad_6', 'final_side_birad_7', + 'final_side_birad_8', 'birad_0', 'birad_1', 'birad_2', 'birad_3', 'birad_4', + 'birad_5', 'birad_6', 'birad_7', 'birad_8', 'calcification_0', + 'calcification_1', 'calcification_2', 'calcification_3', 'calcification_4', + 'calcification_5', 'calcification_6', 'calcification_7', 'calcification_8', + 'calcification_9', 'longitudinal_change_0', 'longitudinal_change_1', + 'longitudinal_change_2', 'longitudinal_change_3', 'longitudinal_change_4', + 'type_0', 'type_1', 'type_2', 'type_3', 'type_4', 'type_5', 'type_6', + 'max_prev_birad_class_0', 'max_prev_birad_class_1', 'max_prev_birad_class_2', + 'max_prev_birad_class_3'] + features_dict['non_visual_feat'] = ['DistanceSourceToPatient', 'DistanceSourceToDetector', 'x_pixel_spacing', + 'XRayTubeCurrent', + 'CompressionForce', 'exposure_time', 'KVP', 'body_part_thickness', + 'RelativeXRayExposure', + 'exposure_in_mas', 'age', 'race_0', 'race_1', 'race_2', 'race_3', 'race_4', 'race_5', + 'race_6', 'race_7', 'race_8', 'race_9', 'race_10'] + + + + return features_dict + + +def tabular_mg(tabular_filename,key_columns): + data = pd.read_csv(tabular_filename) + column_names,column_colIx = get_selected_features_mg(data, 'full',key_columns) + df_tabular = data[column_names] + return df_tabular + + +#------------------Imaging +def imaging_mg(imaging_filename,key_columns): + label_column = 'finding_biopsy' + img_sample_column = 'dcm_url' + + if os.path.exists(imaging_filename): + df = pd.read_csv(imaging_filename) + else: + REVISION_DATE = '20200915' + TableResolver().set_revision(REVISION_DATE) + revision = {'prefix': 'sentara', 'suffix': REVISION_DATE} + engine = Connection().get_engine() + + df_with_findings = get_annotations_and_findings(engine, revision, + exam_types=['MG'], viewpoints=None, # ['CC','MLO'], \ + include_findings=True, remove_invalids=True, + remove_heldout=False, \ + remove_excluded=False, remove_less_than_4views=False, \ + load_from_file=False, save_to_file=False) + + # dicom_table = db.get_table_as_dataframe(engine, tableNames.get_dicom_tags_table_name(revision)) + # study_statuses = db.get_table_as_dataframe(engine, tableNames.get_study_statuses_table_name(revision)) + my_providers = ['sentara'] + df = df_with_findings.loc[df_with_findings['provider'].isin(my_providers)] + # fixing assymetry + asymmetries = ['asymmetry', 'developing asymmetry', 'focal asymmetry', 'global asymmetry'] + df['is_asymmetry'] = df['pathology'].isin(asymmetries) + df['is_Breast_Assymetry'] = df['type'].isin(['Breast Assymetry']) + df.loc[df['is_asymmetry'], 'pathology'] = df[df['is_asymmetry']]['biopsy_outcome'] + df.loc[df['is_Breast_Assymetry'], 'pathology'] = df[df['is_Breast_Assymetry']]['biopsy_outcome'] + # remove duble xmls + aa_unsorted = df + aa_unsorted.sort_values('xml_url', ascending=False, inplace=True) + xml_url_to_keep = aa_unsorted.groupby(['image_id'])['xml_url'].transform('first') + df = aa_unsorted[aa_unsorted['xml_url'] == xml_url_to_keep] + remove_from_pathology = ['undefined', 'not_applicable', 'Undefined', 'extracapsular rupture of breast implant', + 'intracapsular rupture of breast implant'] + is_pathology = ~df.pathology.isnull() & ~df.pathology.isin(remove_from_pathology) + is_digital = df.image_source == 
'Digital' + is_biopsy = df.finding_biopsy.isin(['negative', 'negative high risk', 'positive']) + df = df[(is_digital) & (is_pathology) & (is_biopsy)] + df.to_csv(imaging_filename) + + df1 = df.groupby(key_columns)[img_sample_column].apply(lambda x: list(map(str, x))).reset_index() + df2 = df.groupby(key_columns)[label_column].apply(lambda x: list(map(str, x))).reset_index() + + + return pd.merge(df1,df2,on=key_columns) + + +#------------------Imaging+Tabular +def merge_datasets(tabular_filename,imaging_filename,key_columns): + tabular_data = tabular_mg(tabular_filename, key_columns) + imaging_data = imaging_mg(imaging_filename, key_columns) + tabular_data[key_columns] = tabular_data[key_columns].astype(str) + imaging_data[key_columns] = imaging_data[key_columns].astype(str) + tabular_columns = tabular_data.columns.values + imaging_columns = imaging_data.columns.values + dataset = pd.merge(tabular_data, imaging_data, on=key_columns, how='inner') + return dataset,tabular_columns,imaging_columns + +#------------------Baseline +def apply_gluon_baseline(train_set,test_set,label,save_path): + + predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set) + results = predictor.fit_summary(show_plot=True) + + # Inference time: + y_test = test_set[label] + test_data = test_set.drop(labels=[label], + axis=1) # delete labels from test data since we wouldn't have them in practice + print(test_data.head()) + + predictor = TabularPredictor.load( + save_path) + y_pred = predictor.predict_proba(test_data) + perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) + +#MO: thinkabout specific name +def MG_dataset(tabular_filename:str, + imaging_filename:str, + train_val_test_filenames:list, + + #Mo: internal parameters + imaging_processor, + tabular_processor, + + key_columns:list, + label_key:str, + img_key:str, + sample_key: str, + + cache_dir: str = 'cache', + reset_cache: bool = False, + post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]: + + + dataset, tabular_columns, imaging_columns = merge_datasets(tabular_filename, imaging_filename, key_columns) + + dataset['finding_biopsy'] = [1 if 'positive' in sample else 0 for sample in list(dataset[label_key])] + dataset = dataset.loc[:, ~dataset.columns.duplicated()] + dataset.rename(columns={'patient_id': sample_key}, inplace=True) + + train_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[0])['patient_id'])] + val_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[1])['patient_id'])] + test_set = dataset[dataset[sample_key].isin(pd.read_csv(train_val_test_filenames[2])['patient_id'])] + + features_list = list(tabular_columns) + [features_list.remove(x) for x in key_columns] + train_dataset, validation_dataset, test_dataset = IMAGING_TABULAR_dataset( + df=[train_set, val_set, test_set], + imaging_processor=imaging_processor, + tabular_processor=tabular_processor, + label_key=label_key, + img_key=img_key, + tabular_features_lst=features_list + [label_key] + [sample_key], + sample_key=sample_key, + cache_dir=cache_dir, + reset_cache=reset_cache, + post_cache_processing_func=post_cache_processing_func + ) + + return train_dataset, validation_dataset, test_dataset + + + +if __name__ == "__main__": + data_path = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/mg_radiologist_usa/' + tabular_filename = os.path.join(data_path, 'dataset_MG_clinical.csv') + imaging_filename = 
os.path.join(data_path, 'mg_usa_cohort.csv') + + train_val_test_filenames = [os.path.join(data_path, 'dataset_MG_clinical_train.csv'), + os.path.join(data_path, 'dataset_MG_clinical_validation.csv'), + os.path.join(data_path, 'dataset_MG_clinical_heldout.csv'), ] + + key_columns = ['patient_id'] + fuse_key_column = 'sample_desc' + label_column = 'finding_biopsy' + img_sample_column = 'dcm_url' + train_dataset, validation_dataset, test_dataset = \ + MG_dataset(tabular_filename=tabular_filename, + imaging_filename=imaging_filename, + train_val_test_filenames=train_val_test_filenames, + key_columns=key_columns, + sample_key=fuse_key_column, + label_key=label_column, + img_key=img_sample_column, + cache_dir='./mg_radiologist_usa/', + reset_cache=False, + imaging_processor=FuseMGInputProcessor, + tabular_processor=FuseProcessorDataFrame, + ) + + + # apply_gluon_baseline(train_set[tabular_columns+[label_column]], + # test_set[tabular_columns+[label_column]],label_column,'./Results/MG+clinical/gluon_baseline/') \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/model_tabular_imaging.py b/fuse_examples/classification/multimodality/model_tabular_imaging.py new file mode 100644 index 000000000..08ee2e8a8 --- /dev/null +++ b/fuse_examples/classification/multimodality/model_tabular_imaging.py @@ -0,0 +1,295 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fuse.models.backbones.backbone_mlp import FuseMultilayerPerceptronBackbone +from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict +from typing import Dict, Tuple, Sequence +from fuse.models.backbones.backbone_inception_resnet_v2 import FuseBackboneInceptionResnetV2 + + + +# class Fusefchead(nn.Module): +# def __init__(self, +# cat_representations: Sequence[Tuple[str, int]] = (('model.cat_representations', 1),), +# backbone: FuseMultilayerPerceptronBackbone = FuseMultilayerPerceptronBackbone( +# layers=[2], mlp_input_size=512), +# ) -> None: +# super().__init__() +# +# self.cat_representations = cat_representations +# self.backbone = backbone +# def forward(self, batch_dict: Dict) -> Dict: +# cat_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.cat_representations[0][0]) +# logits = self.backbone(cat_representations) +# preds = F.softmax(logits, dim=1) +# +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.logits', logits) +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.output', preds) +# +# return batch_dict + +# class FuseModelImagingTabularHead(torch.nn.Module): +# def __init__(self, +# backbone: torch.nn.Module, +# heads: Sequence[torch.nn.Module], +# ) -> None: +# super().__init__() +# self.backbone = backbone +# self.heads = torch.nn.ModuleList(heads) +# self.add_module('heads', self.heads) +# +# def forward(self, batch_dict: Dict) -> Dict: +# representations_batch_dict = self.backbone.forward(batch_dict) +# imaging_representations = FuseUtilsHierarchicalDict.get(representations_batch_dict, 'imaging_representations') +# tabular_representations = FuseUtilsHierarchicalDict.get(representations_batch_dict, 'tabular_representations') +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_representations', imaging_representations) +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_representations', tabular_representations) +# if len(imaging_representations.shape)<2: +# imaging_representations = imaging_representations.unsqueeze(dim=0) +# if len(tabular_representations.shape)<2: +# tabular_representations = 
tabular_representations.unsqueeze(dim=0) +# cat_representations = torch.cat((tabular_representations, imaging_representations), 1) +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.cat_representations', cat_representations) +# for head in self.heads: +# batch_dict = head.forward(batch_dict) +# return batch_dict['model'] + +# class Fusesoftmax(nn.Module): +# def __init__(self, +# logits: Sequence[Tuple[str, int]] = (('model.features', 1),), +# ) -> None: +# super().__init__() +# +# self.logits = logits +# def forward(self, batch_dict: Dict) -> Dict: +# logits = FuseUtilsHierarchicalDict.get(batch_dict, self.logits[0][0]) +# preds = F.softmax(logits, dim=1) +# +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.logits', logits) +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.output', preds) +# +# return batch_dict + +# class FuseModelTabularImaging(torch.nn.Module): +# def __init__(self, +# continuous_tabular_input: Tuple[Tuple[str, int], ...], +# categorical_tabular_input: Tuple[Tuple[str, int], ...], +# imaging_inputs: Tuple[Tuple[str, int], ...], +# backbone_categorical_tabular: torch.nn.Module = None, +# backbone_continuous_tabular: torch.nn.Module = None, +# backbone_imaging: torch.nn.Module = None, +# projection_imaging: nn.Conv2d = nn.Conv2d(384, 256, kernel_size=1, stride=1) +# ) -> None: +# super().__init__() +# self.continuous_tabular_input = continuous_tabular_input +# self.categorical_tabular_input = categorical_tabular_input +# self.imaging_inputs = imaging_inputs +# self.backbone_categorical_tabular = backbone_categorical_tabular +# self.backbone_continuous_tabular = backbone_continuous_tabular +# self.backbone_imaging = backbone_imaging +# self.projection_imaging = projection_imaging +# +# def forward(self, batch_dict: Dict) -> Dict: +# +# #tabular encoder +# categorical_input = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0]) +# categorical_embeddings = self.backbone_categorical_tabular(categorical_input) +# continuous_input = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0]) +# input_cat = torch.cat((categorical_embeddings, continuous_input), 1) +# tabular_representations = self.backbone_continuous_tabular(input_cat) #dim 256 +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_representations', tabular_representations) +# +# #imaging encoder +# imaging_input = FuseUtilsHierarchicalDict.get(batch_dict, self.imaging_inputs[0][0]) +# backbone_imaging_features = self.backbone_imaging.forward(imaging_input) +# res = F.max_pool2d(backbone_imaging_features, kernel_size=backbone_imaging_features.shape[2:]) +# imaging_representations = self.projection_imaging.forward(res) +# imaging_representations = torch.squeeze(imaging_representations) +# FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_representations', imaging_representations) +# +# return batch_dict['model'] + + + +# concat model +class TabularImagingConcat(nn.Module): + def __init__(self, pooling='max',projection_imaging: nn.Conv2d = nn.Conv2d(384, 256, kernel_size=1, stride=1)): + super().__init__() + assert pooling in ('max', 'avg') + self.pooling = pooling + self.projection_imaging = projection_imaging + + def fix_imaging(self,imaging_features): + if self.pooling == 'max': + imaging_features = F.max_pool2d(imaging_features, kernel_size=imaging_features.shape[2:]) + + elif self.pooling == 'avg': + imaging_features = F.avg_pool2d(imaging_features, kernel_size=imaging_features.shape[2:]) + + imaging_features = 
self.projection_imaging.forward(imaging_features) + imaging_features = torch.squeeze(torch.squeeze(imaging_features,dim=3),dim=2) + return imaging_features + + + def forward(self, batch_dict): + + imaging_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.imaging_features') + tabular_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.tabular_features') + imaging_features = self.fix_imaging(imaging_features) + res = torch.cat([tabular_features, imaging_features], dim=1) + return res + + +#Tabular model +class FuseModelTabularContinuousCategorical(torch.nn.Module): + def __init__(self, + continuous_tabular_input: Tuple[Tuple[str, int], ...], + categorical_tabular_input: Tuple[Tuple[str, int], ...], + backbone_categorical_tabular: FuseMultilayerPerceptronBackbone, + backbone_continuous_tabular: FuseMultilayerPerceptronBackbone, + heads: Sequence[torch.nn.Module], + ) -> None: + super().__init__() + self.continuous_tabular_input = continuous_tabular_input + self.categorical_tabular_input = categorical_tabular_input + self.backbone_categorical_tabular = backbone_categorical_tabular + self.backbone_cat_tabular = backbone_continuous_tabular + # self.add_module('backbone', self.backbone) + self.heads = torch.nn.ModuleList(heads) + self.add_module('heads', self.heads) + + def forward(self, batch_dict: Dict) -> Dict: + categorical_input = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0]) + categorical_embeddings = self.backbone_categorical_tabular(categorical_input) + continuous_input = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0]) + input_cat = torch.cat((categorical_embeddings, continuous_input), 1) + tabular_features = self.backbone_cat_tabular(input_cat) + FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_features', tabular_features) + + for head in self.heads: + batch_dict = head.forward(batch_dict) + return batch_dict['model'] + +#Tabular Imaging model +class FuseMultiModalityModel(torch.nn.Module): + def __init__(self, + tabular_inputs: Tuple[Tuple[str, int], ...]=None, + imaging_inputs: Tuple[Tuple[str, int], ...]=None, + tabular_backbone: torch.nn.Module=None, + imaging_backbone: torch.nn.Module=None, + multimodal_backbone: torch.nn.Module=None, + tabular_heads: Sequence[torch.nn.Module]=None, + imaging_heads: Sequence[torch.nn.Module]=None, + multimodal_heads: Sequence[torch.nn.Module]=None, + ) -> None: + super().__init__() + + self.tabular_inputs = tabular_inputs + self.tabular_backbone = tabular_backbone + if self.tabular_backbone: + self.add_module('tabular_backbone', self.tabular_backbone) + + self.imaging_inputs = imaging_inputs + self.imaging_backbone = imaging_backbone + if self.imaging_backbone: + self.add_module('imaging_backbone', self.imaging_backbone) + + self.multimodal_backbone = multimodal_backbone + if self.multimodal_backbone: + self.add_module('multimodal_backbone', multimodal_backbone) + + + self.tabular_heads = torch.nn.ModuleList(tabular_heads) + if self.tabular_heads: + self.add_module('tabular_heads', self.tabular_heads) + + self.imaging_heads = torch.nn.ModuleList(imaging_heads) + if self.imaging_heads: + self.add_module('imaging_heads', self.imaging_heads) + + self.multimodal_heads = torch.nn.ModuleList(multimodal_heads) + if self.multimodal_heads: + self.add_module('multimodal_heads', self.multimodal_heads) + + def tabular_modules(self): + return [self.tabular_backbone, self.tabular_heads] + + def imaging_modules(self): + return [self.imaging_backbone, self.imaging_heads] + + 
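# --- Shape walk-through for TabularImagingConcat (illustration only) --------
# fix_imaging() pools the imaging backbone's spatial map down to 1x1, projects
# 384 -> 256 channels with the default 1x1 conv, and squeezes to (N, 256); the
# forward then concatenates it with the (N, 256) tabular features into (N, 512).
# The spatial size below is a toy value; TabularImagingConcat is assumed to be
# the class defined earlier in this file.
import torch
fusion = TabularImagingConcat()
imaging_map = torch.randn(4, 384, 67, 35)               # toy backbone feature map
pooled = fusion.fix_imaging(imaging_map)                # -> torch.Size([4, 256])
fused = torch.cat([torch.randn(4, 256), pooled], dim=1)  # -> torch.Size([4, 512])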
def multimodal_modules(self): + return [self.multimodal_backbone, self.multimodal_heads] + + def forward(self, batch_dict: Dict) -> Dict: + + if self.tabular_backbone: + tabular_features = self.tabular_backbone.forward(batch_dict) + + if self.imaging_backbone: + imaging_input = FuseUtilsHierarchicalDict.get(batch_dict, self.imaging_inputs[0][0]) + imaging_features = self.imaging_backbone.forward(imaging_input) + FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_features', imaging_features) + + if self.multimodal_backbone: + multimodal_features = self.multimodal_backbone.forward(batch_dict) + FuseUtilsHierarchicalDict.set(batch_dict, 'model.multimodal_features', multimodal_features) + + + # run through heads + if self.tabular_heads: + for head in self.tabular_heads: + batch_dict = head.forward(batch_dict) + + if self.imaging_heads: + for head in self.imaging_heads: + batch_dict = head.forward(batch_dict) + + if self.multimodal_heads: + for head in self.multimodal_heads: + batch_dict = head.forward(batch_dict) + + return batch_dict['model'] + + +if __name__ == '__main__': + import torch + + batch_dict = {'data.continuous': torch.randn(8, 14), + 'data.categorical': torch.randn(8, 63), + 'data.image': torch.randn(8, 1, 2200, 1200)} + + # model = FuseModelTabularImaging( + # continuous_tabular_input=(('data.continuous', 1),), + # categorical_tabular_input=(('data.categorical', 1),), + # imaging_inputs=(('data.patch.input.input_0', 1),),) + # + # res = model(batch_dict) + + # model = FuseModelTabularContinuousCategorical( + # continuous_tabular_input=(('data.continuous', 1),), + # categorical_tabular_input=(('data.categorical', 1),), + # backbone_categorical_tabular=FuseMultilayerPerceptronBackbone(layers=[128, 63],mlp_input_size=63), + # backbone_continuous_tabular = FuseMultilayerPerceptronBackbone( layers=[256],mlp_input_size=77), + # heads=None, + # ) + # res = model.forward(batch_dict) + + model_tabular = FuseModelTabularContinuousCategorical( + continuous_tabular_input=(('data.continuous', 1),), + categorical_tabular_input=(('data.categorical', 1),), + backbone_categorical_tabular=FuseMultilayerPerceptronBackbone(layers=[128, 63],mlp_input_size=63), + backbone_continuous_tabular = FuseMultilayerPerceptronBackbone( layers=[256],mlp_input_size=77), + heads=None, + ) + model_imaging = FuseBackboneInceptionResnetV2(input_channels_num=1) + model_multimodel = TabularImagingConcat() + + model = FuseMultiModalityModel( + tabular_inputs=(('data.continuous', 1),('data.categorical', 1),), + imaging_inputs=(('data.image', 1),), + tabular_backbone=model_tabular, + imaging_backbone=model_imaging, + multimodal_backbone = model_multimodel, + ) + res = model.forward(batch_dict) + a=1 \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/multimodal_paths.py b/fuse_examples/classification/multimodality/multimodal_paths.py new file mode 100644 index 000000000..4112c8f12 --- /dev/null +++ b/fuse_examples/classification/multimodality/multimodal_paths.py @@ -0,0 +1,48 @@ +import os + +def multimodal_paths(dataset_name,root_data,root, experiment,cache_path): + if dataset_name=='mg_clinical': + paths = { + # paths + 'data_dir': root_data, + 'tabular_filename': os.path.join(root_data, 'mg_clinical_dicom_sentra/fx_sentara_cohort_processed.csv'), + 'imaging_filename': os.path.join(root_data, 'mg_clinical_dicom_sentra/mg_sentara_cohort.csv'), + 'train_val_test_filenames': [os.path.join(root_data, 'mg_clinical_dicom_sentra/sentara_train_pathologies.csv'), + os.path.join(root_data, 
'mg_clinical_dicom_sentra/sentara_val_pathologies.csv'), + os.path.join(root_data, 'mg_clinical_dicom_sentra/sentara_test_pathologies.csv'), ], + + # keys to extract from dataframe + 'key_columns': ['patient_id', 'study_id'], + 'sample_key': 'sample_desc', + 'label_key': 'finding_biopsy', + 'img_key': 'dcm_url', + + 'model_dir': os.path.join(root, experiment, 'model_mg_clinical_dicom_sentra'), + 'force_reset_model_dir': True, + # If True, resets the model dir automatically - otherwise prompts an 'are you sure' message. + 'cache_dir': cache_path, + 'inference_dir': os.path.join(root, experiment, 'infer_mg_clinical_dicom_sentra')} + elif dataset_name == 'mg_radiologic': + paths = { + # paths + 'data_dir': root_data, + 'tabular_filename': os.path.join(root_data, 'mg_radiologist_usa/dataset_MG_clinical.csv'), + 'imaging_filename': os.path.join(root_data, 'mg_radiologist_usa/mg_usa_cohort.csv'), + 'train_val_test_filenames': [os.path.join(root_data, 'mg_radiologist_usa/dataset_MG_clinical_train.csv'), + os.path.join(root_data, 'mg_radiologist_usa/dataset_MG_clinical_validation.csv'), + os.path.join(root_data, 'mg_radiologist_usa/dataset_MG_clinical_heldout.csv'), ], + + # keys to extract from dataframe + 'key_columns': ['patient_id'], + 'sample_key': 'sample_desc', + 'label_key': 'finding_biopsy', + 'img_key': 'dcm_url', + + 'model_dir': os.path.join(root_data,'model_mg_radiologist_usa/'+experiment), + 'force_reset_model_dir': True, + # If True, resets the model dir automatically - otherwise prompts an 'are you sure' message. + 'cache_dir': cache_path, + 'inference_dir': os.path.join(root_data,'model_mg_radiologist_usa/'+experiment)} + else: + raise ValueError(f'unsupported dataset_name: {dataset_name}') + + return paths \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/multimodel_parameters.py b/fuse_examples/classification/multimodality/multimodel_parameters.py new file mode 100644 index 000000000..ad05b2fab --- /dev/null +++ b/fuse_examples/classification/multimodality/multimodel_parameters.py @@ -0,0 +1,216 @@ +import torch.optim as optim +from fuse_examples.classification.multimodality.model_tabular_imaging import * +from fuse.losses.loss_default import FuseLossDefault +import torch.nn.functional as F +from fuse.metrics.classification.metric_auc import FuseMetricAUC +from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy +from fuse_examples.classification.multimodality.loss_multimodal_contrastive_learning import FuseLossMultimodalContrastiveLearning +from fuse.models.model_default import FuseModelDefault +from fuse.models.heads.head_global_pooling_classifier import FuseHeadGlobalPoolingClassifier +from fuse.models.heads.head_1d_classifier import FuseHead1dClassifier +from fuse.models.model_ensemble import FuseModelEnsemble + +def multimodal_parameters(train_common_params: dict,infer_common_params: dict,analyze_common_params: dict): + + + ################################################ + # backbone models + model_tabular = FuseModelTabularContinuousCategorical( + continuous_tabular_input=(('data.continuous', 1),), + categorical_tabular_input=(('data.categorical', 1),), + backbone_categorical_tabular=train_common_params['tabular_encoder_categorical'], + backbone_continuous_tabular = train_common_params['tabular_encoder_continuous'], + heads=None, + ) + + model_imaging = train_common_params['imaging_encoder'] + model_multimodel_concat = TabularImagingConcat() + heads_for_multimodal = { + 'multimodal_head': + [ + FuseHead1dClassifier( + head_name='multimodal',
conv_inputs=(('model.multimodal_features', + train_common_params['tabular_feature_size'] * 2),), + num_classes=2, + ) + ], + 'tabular_head': + [ + FuseHead1dClassifier( + head_name='tabular', + conv_inputs=(('model.tabular_features', + train_common_params['tabular_feature_size']),), + num_classes=2, + ) + ], + 'imaging_head': + [ + FuseHeadGlobalPoolingClassifier( + head_name='imaging', + dropout_rate=0.5, + layers_description=(256,), + conv_inputs=(('model.imaging_features', + train_common_params['imaging_feature_size']),), + num_classes=2, + pooling="avg", + ) + ], + + + } + loss_for_multimodal = { + 'multimodal_loss':FuseLossDefault(pred_name='model.logits.multimodal', target_name='data.gt', + callable=F.cross_entropy, weight=1.0), + 'tabular_loss':FuseLossDefault(pred_name='model.logits.tabular', target_name='data.gt', + callable=F.cross_entropy, weight=1.0), + 'imaging_loss':FuseLossDefault(pred_name='model.logits.imaging', target_name='data.gt', + callable=F.cross_entropy, weight=1.0), + 'ensemble_loss':FuseLossDefault(pred_name='model.output.tabular_ensemble_average', target_name='data.gt', + callable=F.nll_loss, weight=1.0,reduction='sum'), + + } + metric_for_multimodal = { + 'multimodal_auc': FuseMetricAUC(pred_name='model.output.multimodal', target_name='data.gt'), + 'tabular_auc': FuseMetricAUC(pred_name='model.output.tabular', target_name='data.gt'), + 'imaging_auc': FuseMetricAUC(pred_name='model.output.imaging', target_name='data.gt'), + 'ensemble_auc':FuseMetricAUC(pred_name='model.output.tabular_ensemble_average', target_name='data.gt'), + } + ################################################ + + + if train_common_params['fusion_type'] == 'mono_tabular': + + train_common_params['model'] = FuseMultiModalityModel( + tabular_inputs=(('data.continuous', 1), ('data.categorical', 1),), + tabular_backbone=model_tabular, + tabular_heads=heads_for_multimodal['tabular_head'], + ) + train_common_params['loss'] = { + 'cls_loss': loss_for_multimodal['tabular_loss'], + } + train_common_params['metrics'] = { + 'auc': metric_for_multimodal['tabular_auc'], + } + + train_common_params['manager.learning_rate'] = 1e-4 + train_common_params['manager.weight_decay'] = 1e-4 + train_common_params['manager.momentum'] = 0.9 + train_common_params['manager.step_size'] = 150 + train_common_params['manager.gamma'] = 0.1 + train_common_params['optimizer'] = optim.SGD(train_common_params['model'].parameters(), + lr=train_common_params['manager.learning_rate'], + momentum=train_common_params['manager.momentum'], + weight_decay=train_common_params['manager.weight_decay']) + train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], + gamma=train_common_params['manager.gamma']) + + if train_common_params['fusion_type'] == 'mono_imaging': + + train_common_params['model'] = FuseMultiModalityModel( + imaging_inputs=(('data.image', 1),), + imaging_backbone=model_imaging, + imaging_heads=heads_for_multimodal['imaging_head'], + ) + train_common_params['loss'] = { + 'cls_loss': loss_for_multimodal['imaging_loss'], + } + train_common_params['metrics'] = { + 'auc': metric_for_multimodal['imaging_auc'], + } + train_common_params['manager.learning_rate'] = 1e-5 + train_common_params['manager.weight_decay'] = 0.001 + + train_common_params['optimizer'] = optim.Adam(train_common_params['model'].parameters(), lr=train_common_params['manager.learning_rate'], + weight_decay=train_common_params['manager.weight_decay']) + 
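+        # Unlike the StepLR schedules used for the other fusion modes, ReduceLROnPlateau
+        # lowers the learning rate only when a monitored value stops improving, so it must
+        # be stepped as scheduler.step(value); it is assumed here that the Fuse manager
+        # passes its tracked 'best_epoch_source' value when stepping the scheduler.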
train_common_params['scheduler'] = optim.lr_scheduler.ReduceLROnPlateau(train_common_params['optimizer']) + + if train_common_params['fusion_type'] == 'late_fusion': + + train_common_params['model'] = FuseMultiModalityModel( + tabular_inputs=(('data.continuous', 1), ('data.categorical', 1),), + imaging_inputs=(('data.image', 1),), + tabular_backbone=model_tabular, + imaging_backbone=model_imaging, + multimodal_backbone=model_multimodel_concat, + imaging_heads=heads_for_multimodal['imaging_head'], + tabular_heads=heads_for_multimodal['tabular_head'], + multimodal_heads=heads_for_multimodal['multimodal_head'], + ) + + train_common_params['loss'] = { + 'cls_loss': loss_for_multimodal['multimodal_loss'], + } + train_common_params['metrics'] = { + 'auc': metric_for_multimodal['multimodal_auc'], + } + train_common_params['manager.learning_rate'] = 1e-4 + train_common_params['manager.weight_decay'] = 1e-4 + train_common_params['manager.momentum'] = 0.9 + train_common_params['manager.step_size'] = 150 + train_common_params['manager.gamma'] = 0.1 + train_common_params['optimizer'] = optim.SGD(train_common_params['model'].parameters(), + lr=train_common_params['manager.learning_rate'], + momentum=train_common_params['manager.momentum'], + weight_decay=train_common_params['manager.weight_decay']) + train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], + gamma=train_common_params['manager.gamma']) + + if train_common_params['fusion_type'] == 'ensemble': + + train_common_params['tabular_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/model_mg_radiologist_usa/mono_tabular/' + train_common_params['imaging_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/model_mg_radiologist_usa/mono_imaging_no_aug/' + train_common_params['model'] = FuseModelEnsemble(input_model_dirs=[train_common_params['tabular_dir'], + train_common_params['imaging_dir']]) + infer_common_params['model_dir'] = [train_common_params['tabular_dir'], + train_common_params['imaging_dir']] + + infer_common_params['output_keys'] = ['data.gt', + 'model.output.ensemble_output_0.tabular', + 'model.output.ensemble_output_1.tabular', + 'model.output.ensemble_output_1.imaging', + 'model.output.tabular_ensemble_average', + 'model.output.tabular_ensemble_majority_vote'] + + analyze_common_params['metrics'] = train_common_params['metrics'] = { + 'auc': metric_for_multimodal['ensemble_auc'], + } + # Ensemble mode runs inference over the pre-trained mono models, so no + # loss, optimizer, or scheduler is configured for this fusion type. + + # Contrastive pretraining uses a dedicated two-tower model and a contrastive + # loss instead of classification heads; its parameters therefore differ from + # the other fusion types. + if train_common_params['fusion_type'] == 'contrastive': + train_common_params['model'] = FuseModelTabularImaging( + continuous_tabular_input=(('data.continuous', 1),), + categorical_tabular_input=(('data.categorical', 1),), + imaging_inputs=(('data.image', 1),), + backbone_categorical_tabular=train_common_params['tabular_encoder_categorical'],
backbone_continuous_tabular=train_common_params['tabular_encoder_continuous'], + backbone_imaging=train_common_params['imaging_encoder'], + + ) + + train_common_params['loss'] = FuseLossMultimodalContrastiveLearning( + imaging_representations='model.imaging_representations', + tabular_representations='model.tabular_representations', + label='data.gt', + temperature=0.1, + alpha=0.5) + train_common_params['metrics'] = None + + return train_common_params,infer_common_params,analyze_common_params \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/runner.py b/fuse_examples/classification/multimodality/runner.py new file mode 100644 index 000000000..ca5aea080 --- /dev/null +++ b/fuse_examples/classification/multimodality/runner.py @@ -0,0 +1,391 @@ +""" + +(C) Copyright 2021 IBM Corp. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +Created on June 30, 2021 + +""" + +import os + +from fuse.utils.utils_debug import FuseUtilsDebug +from fuse.utils.utils_gpu import FuseUtilsGPU + +import logging + +import torch.optim as optim +from torch.utils.data.dataloader import DataLoader + +from fuse.utils.utils_logger import fuse_logger_start + +from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch + +from fuse.metrics.classification.metric_auc import FuseMetricAUC +from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy + +from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback +from fuse.managers.callbacks.callback_metric_statistics import FuseMetricStatisticsCallback +from fuse.managers.callbacks.callback_time_statistics import FuseTimeStatisticsCallback +from fuse.managers.manager_default import FuseManagerDefault + +from fuse_examples.classification.multimodality.mg_dataset_radiologist import MG_dataset +from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths +from fuse.models.backbones.backbone_inception_resnet_v2 import FuseBackboneInceptionResnetV2 +from fuse.models.backbones.backbone_mlp import FuseMultilayerPerceptronBackbone + + +from fuse.analyzer.analyzer_default import FuseAnalyzerDefault +from fuse.metrics.classification.metric_roc_curve import FuseMetricROCCurve + + +from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor +from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame +from fuse_examples.classification.multimodality.mg_dataset_radiologist import PostProcessing,tabular_feature_mg +from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters + + + + + +########################################## +# Debug modes +########################################## +mode = 'default' # Options: 'default', 'fast', 'debug', 'verbose', 'user'. 
See details in FuseUtilsDebug +debug = FuseUtilsDebug(mode) + +########################################## +# Train Common Params +########################################## +# ============ +# Data +# ============ +TRAIN_COMMON_PARAMS = {} + +TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8 +TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8 + +########################################## +# Dataset +########################################## +dataset_name = 'mg_radiologic' +root = '' +root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/' # TODO: add path to the data folder +assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location" +# Name of the experiment +experiment = 'late_fusion_non_visual' +# Path to cache data +cache_path = root_data+'/mg_radiologic/' + +paths = multimodal_paths(dataset_name, root_data, root, experiment, cache_path) +TRAIN_COMMON_PARAMS['paths'] = paths +TRAIN_COMMON_PARAMS['fusion_type'] = 'late_fusion' +###################################### +# Inference Common Params +###################################### +INFER_COMMON_PARAMS = {} +INFER_COMMON_PARAMS['infer_filename'] = os.path.join(TRAIN_COMMON_PARAMS['paths']['inference_dir'],'validation_set_infer.gz') +INFER_COMMON_PARAMS['checkpoint'] = 'best' # Fuse TIP: possible values are 'best', 'last' or epoch_index. +INFER_COMMON_PARAMS['data.train_num_workers'] = TRAIN_COMMON_PARAMS['data.train_num_workers'] + +# Analyze Common Params +###################################### +ANALYZE_COMMON_PARAMS = {} +ANALYZE_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename'] +ANALYZE_COMMON_PARAMS['output_filename'] = os.path.join(TRAIN_COMMON_PARAMS['paths']['inference_dir'],'all_metrics.txt') +ANALYZE_COMMON_PARAMS['num_workers'] = 4 +ANALYZE_COMMON_PARAMS['batch_size'] = 8 + +# =============== +# Manager - Train +# =============== +NUM_GPUS = 2 +TRAIN_COMMON_PARAMS['data.batch_size'] = 2 * NUM_GPUS +TRAIN_COMMON_PARAMS['manager.train_params'] = { + 'num_gpus': NUM_GPUS, + 'num_epochs': 300, + 'virtual_batch_size': 1, # number of batches in one virtual batch + 'start_saving_epochs': 10, # first epoch to start saving checkpoints from + 'gap_between_saving_epochs': 5, # number of epochs between saved checkpoint +} + +# best_epoch_source +# if an epoch values are the best so far, the epoch is saved as a checkpoint. 
+TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = { + 'source':'metrics.auc.macro_avg',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries + 'optimization': 'max', # can be either min/max + 'on_equal_values': 'better', + # can be either better/worse - whether to consider best epoch when values are equal +} + +TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None + + + +#define postprocessing function +features_dic = tabular_feature_mg() +TRAIN_COMMON_PARAMS['post_processing'] = PostProcessing(features_dic['continuous_clinical_feat'], + features_dic['categorical_clinical_feat'], + ['gt'], + features_dic['visual_feat'], + features_dic['non_visual_feat'], + use_imaging=False, use_non_imaging=True) + +#define processors +TRAIN_COMMON_PARAMS['imaging_processor'] = FuseMGInputProcessor +TRAIN_COMMON_PARAMS['tabular_processor'] = FuseProcessorDataFrame + +#define encoders +TRAIN_COMMON_PARAMS['imaging_feature_size'] = 384 +TRAIN_COMMON_PARAMS['tabular_feature_size'] = 256 +TRAIN_COMMON_PARAMS['tabular_encoder_categorical'] = FuseMultilayerPerceptronBackbone( + layers=[128, len(features_dic['categorical_clinical_feat'])], + mlp_input_size=len(features_dic['categorical_clinical_feat'])) +TRAIN_COMMON_PARAMS['tabular_encoder_continuous'] = FuseMultilayerPerceptronBackbone( + layers=[TRAIN_COMMON_PARAMS['tabular_feature_size']], + mlp_input_size=len(features_dic['categorical_clinical_feat'])+\ + len(features_dic['continuous_clinical_feat'])) + +TRAIN_COMMON_PARAMS['imaging_encoder'] = FuseBackboneInceptionResnetV2(input_channels_num=1) + + + +TRAIN_COMMON_PARAMS['dataset_func'] = MG_dataset( + tabular_filename=TRAIN_COMMON_PARAMS['paths']['tabular_filename'], + imaging_filename=TRAIN_COMMON_PARAMS['paths']['imaging_filename'], + train_val_test_filenames=TRAIN_COMMON_PARAMS['paths']['train_val_test_filenames'], + key_columns=TRAIN_COMMON_PARAMS['paths']['key_columns'], + label_key=TRAIN_COMMON_PARAMS['paths']['label_key'], + img_key=TRAIN_COMMON_PARAMS['paths']['img_key'], + sample_key=TRAIN_COMMON_PARAMS['paths']['sample_key'], + + imaging_processor=TRAIN_COMMON_PARAMS['imaging_processor'], + tabular_processor=TRAIN_COMMON_PARAMS['tabular_processor'], + + cache_dir=TRAIN_COMMON_PARAMS['paths']['cache_dir'], + reset_cache=False, + post_cache_processing_func=TRAIN_COMMON_PARAMS['post_processing'], + ) + +TRAIN_COMMON_PARAMS,INFER_COMMON_PARAMS,ANALYZE_COMMON_PARAMS = multimodal_parameters(TRAIN_COMMON_PARAMS,INFER_COMMON_PARAMS,ANALYZE_COMMON_PARAMS) + +################################# +# Train Template +################################# +def run_train(paths: dict, train_common_params: dict, reset_cache: bool): + # ============================================================================== + # Logger + # ============================================================================== + fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO) + lgr = logging.getLogger('Fuse') + + # Download data + # TBD + + lgr.info('\nFuse Train', {'attrs': ['bold', 'underline']}) + + lgr.info(f'model_dir={paths["model_dir"]}', {'color': 'magenta'}) + lgr.info(f'cache_dir={paths["cache_dir"]}', {'color': 'magenta'}) + + #### Train Data + lgr.info(f'Train Data:', {'attrs': 'bold'}) + #Mo: function that returns dataset datasetfun in train_params, kwargs + train_dataset, validation_dataset, _ = train_common_params['dataset_func'] + + + ## Create sampler + lgr.info(f'- Create sampler:') + sampler = 
FuseSamplerBalancedBatch(dataset=train_dataset, + balanced_class_name='data.gt', + num_balanced_classes=2, + batch_size=train_common_params['data.batch_size'], + use_dataset_cache=True) + + lgr.info(f'- Create sampler: Done') + + ## Create dataloader + train_dataloader = DataLoader(dataset=train_dataset, + shuffle=False, drop_last=False, + batch_sampler=sampler, collate_fn=train_dataset.collate_fn, + num_workers=train_common_params['data.train_num_workers']) + lgr.info(f'Train Data: Done', {'attrs': 'bold'}) + + #### Validation data + lgr.info(f'Validation Data:', {'attrs': 'bold'}) + + ## Create dataloader + validation_dataloader = DataLoader(dataset=validation_dataset, + shuffle=False, + drop_last=False, + batch_sampler=None, + batch_size=train_common_params['data.batch_size'], + num_workers=train_common_params['data.validation_num_workers'], + collate_fn=validation_dataset.collate_fn) + lgr.info(f'Validation Data: Done', {'attrs': 'bold'}) + + + # ============================================================================== + # Model + # ============================================================================== + lgr.info('Model:', {'attrs': 'bold'}) + + model = train_common_params['model'] + + lgr.info('Model: Done', {'attrs': 'bold'}) + + # ==================================================================================== + # Loss + # ==================================================================================== + + losses = train_common_params['loss'] + + # ==================================================================================== + # Metrics + # ==================================================================================== + + metrics = train_common_params['metrics'] + + # ===================================================================================== + # Callbacks + # ===================================================================================== + callbacks = [ + # default callbacks + FuseTensorboardCallback(model_dir=paths['model_dir']), # save statistics for tensorboard + FuseMetricStatisticsCallback(output_path=paths['model_dir'] + "/metrics.csv"), # save statistics in a csv file + FuseTimeStatisticsCallback(num_epochs=train_common_params['manager.train_params']['num_epochs'], + load_expected_part=0.1) # time profiler + ] + + # ===================================================================================== + # Manager - Train + # Create a manager, training objects and run a training process. + # ===================================================================================== + lgr.info('Train:', {'attrs': 'bold'}) + + # create optimizer + optimizer = train_common_params['optimizer'] + # create scheduler + scheduler = train_common_params['scheduler'] + + # train from scratch + manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir']) + # Providing the objects required for the training process. + manager.set_objects(net=model, + optimizer=optimizer, + losses=losses, + metrics=metrics, + best_epoch_source=train_common_params['manager.best_epoch_source'], + lr_scheduler=scheduler, + callbacks=callbacks, + train_params=train_common_params['manager.train_params'], + output_model_dir=paths['model_dir']) + + # Continue training + if train_common_params['manager.resume_checkpoint_filename'] is not None: + # Load the checkpoint, including model weights, learning rate, and epoch_index.
+ manager.load_checkpoint(checkpoint=train_common_params['manager.resume_checkpoint_filename'], mode='train', + values_to_resume=['net']) + # Start training + manager.train(train_dataloader=train_dataloader, + validation_dataloader=validation_dataloader) + + lgr.info('Train: Done', {'attrs': 'bold'}) + + +###################################### +# Inference Template +###################################### +def run_infer(paths: dict, train_common_params: dict,infer_common_params: dict): + #### Logger + fuse_logger_start(output_path=paths['inference_dir'], console_verbose_level=logging.INFO) + lgr = logging.getLogger('Fuse') + lgr.info('Fuse Inference', {'attrs': ['bold', 'underline']}) + lgr.info(f'infer_filename={os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])}', + {'color': 'magenta'}) + + # Create data source: inference runs on the validation split + _, val_dataset, test_dataset = train_common_params['dataset_func'] + + ## Create dataloader + infer_dataloader = DataLoader(dataset=val_dataset, + shuffle=False, drop_last=False, + collate_fn=val_dataset.collate_fn, + num_workers=infer_common_params['data.train_num_workers']) + lgr.info(f'Infer Data: Done', {'attrs': 'bold'}) + + #### Manager for inference + manager = FuseManagerDefault() + # extract just the global classification per sample and save to a file + output_columns = infer_common_params['output_keys'] + manager.infer(data_loader=infer_dataloader, + input_model_dir=infer_common_params['model_dir'], + checkpoint=infer_common_params['checkpoint'], + output_columns=output_columns, + output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])) + + ###################################### + + + +###################################### +# Analyze Template +###################################### +def run_analyze(paths: dict, analyze_common_params: dict): + fuse_logger_start(output_path=None, console_verbose_level=logging.INFO) + lgr = logging.getLogger('Fuse') + lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']}) + + # metrics + metrics = analyze_common_params['metrics'] + + + # create analyzer + analyzer = FuseAnalyzerDefault() + + # run + # FIXME: simplify analyze interface for this case + results = analyzer.analyze(gt_processors={}, + data_pickle_filename=os.path.join(paths["inference_dir"], + analyze_common_params["infer_filename"]), + metrics=metrics, + output_filename=analyze_common_params['output_filename']) + + return results + + +###################################### +# Run +###################################### + + +if __name__ == "__main__": + # allocate gpus + if NUM_GPUS == 0: + TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu' + # uncomment if you want to use specific gpus instead of automatically looking for free ones + force_gpus = None # [0] + FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus) + + RUNNING_MODES = ['train'] # ['train', 'infer', 'analyze'] # Options: 'train', 'infer', 'analyze' + + + paths = TRAIN_COMMON_PARAMS['paths'] + # train + if 'train' in RUNNING_MODES: + run_train(paths=paths, train_common_params=TRAIN_COMMON_PARAMS, reset_cache=False) + + # infer + if 'infer' in RUNNING_MODES: + run_infer(paths=paths, train_common_params=TRAIN_COMMON_PARAMS,infer_common_params=INFER_COMMON_PARAMS) + + # analyze + if 'analyze' in RUNNING_MODES: + run_analyze(paths=paths, analyze_common_params=ANALYZE_COMMON_PARAMS) From 39377b9a399a883498c5bb13d9bbe67f04bc2092 Mon Sep 17 00:00:00 2001 From: ttlusty Date: Sun, 1 May 2022 14:30:17 +0300 Subject:
[PATCH 2/7] refactoring example --- .../classification/multimodality/dataset.py | 131 ++++++++++-------- .../multimodality/mg_dataset.py | 4 +- ...=> mg_dataset_clinical_and_annotations.py} | 112 ++++++--------- .../classification/multimodality/runner.py | 39 ++---- 4 files changed, 132 insertions(+), 154 deletions(-) rename fuse_examples/classification/multimodality/{mg_dataset_radiologist.py => mg_dataset_clinical_and_annotations.py} (77%) diff --git a/fuse_examples/classification/multimodality/dataset.py b/fuse_examples/classification/multimodality/dataset.py index 4cf96ce47..99ada0af9 100644 --- a/fuse_examples/classification/multimodality/dataset.py +++ b/fuse_examples/classification/multimodality/dataset.py @@ -1,11 +1,7 @@ -import sys -from typing import Callable, Optional +from typing import Callable, Optional, Tuple, Any, Iterable import logging import pandas as pd -import pydicom -import os, glob -from pathlib import Path -from typing import Tuple +from typing import List from fuse.data.visualizer.visualizer_default import FuseVisualizerDefault from fuse.data.augmentor.augmentor_default import FuseAugmentorDefault @@ -13,27 +9,16 @@ from fuse.data.dataset.dataset_default import FuseDatasetDefault from fuse.data.dataset.dataset_generator import FuseDatasetGenerator from fuse.data.data_source.data_source_default import FuseDataSourceDefault - -from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform -from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt -from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool - +from fuse.data.processor.processor_base import FuseProcessorBase +from fuse.utils.rand.param_sampler import Uniform, RandInt, RandBool from fuse_examples.classification.multimodality.input_processor import ImagingTabularProcessor - - -def IMAGING_dataset(): +def imaging_augmentation()-> Iterable[Any]: """ - Creates Fuse Dataset object for training, validation and test - :param data_dir: dataset root path - :param data_misc_dir path to save misc files to be used later - :param cache_dir: Optional, name of the cache folder - :param reset_cache: Optional,specifies if we want to clear the cache first - :param post_cache_processing_func: Optional, function run post cache processing - :return: training, validation and test FuseDatasetDefault objects + :return: augmentation_pipeline iterator """ augmentation_pipeline = [ [ @@ -58,40 +43,59 @@ def IMAGING_dataset(): ], ] - - # Create data augmentation (optional) - augmentor = FuseAugmentorDefault( - augmentation_pipeline=augmentation_pipeline) - - - + augmentor = FuseAugmentorDefault(augmentation_pipeline=augmentation_pipeline) return augmentor -def TABULAR_dataset(tabular_processor,df,tabular_features,sample_key): - tabular_features.remove(sample_key) - tabular_processor = tabular_processor(data=df, - sample_desc_column=sample_key, - columns_to_extract=tabular_features + [sample_key], - columns_to_tensor=tabular_features) - return tabular_processor - - -def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:str,img_key:str,tabular_features_lst: list,sample_key: str, - cache_dir: str = 'cache', reset_cache: bool = False, - post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]: +# def tabular_dataset(tabular_processor,df,tabular_features,sample_key): +# +# +# tabular_features.remove(sample_key) +# tabular_processor = tabular_processor(data=df, +# sample_desc_column=sample_key, 
+# columns_to_extract=tabular_features + [sample_key], +# columns_to_tensor=tabular_features) +# return tabular_processor + + +def imaging_tabular_dataset(data_split: List[pd.Dataframe], + imaging_processor: FuseProcessorBase, + tabular_processor: FuseProcessorBase, + label_key:str, + img_key:str, + sample_key: str, + tabular_features_lst: list, + cache_dir: str = 'cache', + reset_cache: bool = False, + post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]: + """ + Creates Fuse Dataset object for training, validation and test + :param data_split: A list of train, validation and test dataframes + :param imaging_processor: Imaging data generator + :param tabular_processor: Tabular data generator + :param label_key Name of label to use from dataframe + :param img_key Name of image path column + :param sample_key Name of sample id + :param tabular_features_lst a list of tabular keys to use + :param cache_dir: Optional, name of the cache folder + :param reset_cache: Optional,specifies if we want to clear the cache first + :param post_cache_processing_func: Optional, function run post cache processing + :return: training, validation and test FuseDatasetDefault objects + """ lgr = logging.getLogger('Fuse') - if isinstance(df,list): - df_train = df[0] - if len(df)>1: - df_val = df[1] - if len(df)>2: - df_test = df[2] + if isinstance(data_split,list): + df_train = data_split[0] + if len(data_split)>1: + df_val = data_split[1] + if len(data_split)>2: + df_test = data_split[2] + else: + raise Exception(f'current version supports train/val/test data division') #---------------------------------------------- # -----Datasource @@ -100,27 +104,46 @@ def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:s test_data_source = FuseDataSourceDefault(input_source=df_test) # ---------------------------------------------- + + tabular_features_lst.remove(sample_key) + # tabular_processor = tabular_processor(data=df, + # sample_desc_column=sample_key, + # columns_to_extract=tabular_features_lst + [sample_key], + # columns_to_tensor=tabular_features_lst) + # -----Data-processors img_clinical_processor_train = ImagingTabularProcessor(data=df_train, label=label_key, img_key = img_key, image_processor=imaging_processor(''), - tabular_processor= \ - TABULAR_dataset(tabular_processor,df_train,tabular_features_lst.copy(),sample_key)) + tabular_processor=tabular_processor(data=df_train, + sample_desc_column=sample_key, + columns_to_extract=tabular_features_lst + [sample_key], + columns_to_tensor=tabular_features_lst) + # tabular_dataset(tabular_processor,df_train,tabular_features_lst.copy(),sample_key) + ) img_clinical_processor_val = ImagingTabularProcessor(data=df_val, label=label_key, img_key=img_key, image_processor=imaging_processor(''), - tabular_processor=\ - TABULAR_dataset(tabular_processor,df_val,tabular_features_lst.copy(),sample_key)) + tabular_processor=tabular_processor(data=df_val, + sample_desc_column=sample_key, + columns_to_extract=tabular_features_lst + [sample_key], + columns_to_tensor=tabular_features_lst) + # tabular_dataset(tabular_processor,df_val,tabular_features_lst.copy(),sample_key) + ) img_clinical_processor_test = ImagingTabularProcessor(data=df_test, label=label_key, img_key=img_key, image_processor=imaging_processor(''), - tabular_processor= \ - TABULAR_dataset(tabular_processor,df_test,tabular_features_lst.copy(),sample_key)) + tabular_processor=tabular_processor(data=df_test, + sample_desc_column=sample_key, + 
columns_to_extract=tabular_features_lst + [sample_key], + columns_to_tensor=tabular_features_lst) + # tabular_dataset(tabular_processor,df_test,tabular_features_lst.copy(),sample_key) + ) @@ -132,7 +155,7 @@ def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:s train_dataset = FuseDatasetGenerator(cache_dest=cache_dir, data_source=train_data_source, processor=img_clinical_processor_train, - augmentor=IMAGING_dataset(), + augmentor=imaging_augmentation(), visualizer=visualiser, post_processing_func=post_cache_processing_func,) @@ -151,10 +174,8 @@ def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:s visualizer=visualiser, post_processing_func=post_cache_processing_func,) - # ---------------------------------------------- # ------ Cache - # create cache train_dataset.create(reset_cache=reset_cache) # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading validation_dataset.create() # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading diff --git a/fuse_examples/classification/multimodality/mg_dataset.py b/fuse_examples/classification/multimodality/mg_dataset.py index adcb7a1ff..7bce70174 100644 --- a/fuse_examples/classification/multimodality/mg_dataset.py +++ b/fuse_examples/classification/multimodality/mg_dataset.py @@ -11,7 +11,7 @@ # from autogluon.tabular import TabularPredictor -from fuse_examples.classification.multimodality.dataset import IMAGING_TABULAR_dataset +from fuse_examples.classification.multimodality.dataset import imaging_tabular_dataset from fuse.data.dataset.dataset_default import FuseDatasetDefault from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor @@ -383,7 +383,7 @@ def MG_dataset(tabular_filename:str, features_list = list(tabular_columns) [features_list.remove(x) for x in key_columns] - train_dataset, validation_dataset, test_dataset = IMAGING_TABULAR_dataset( + train_dataset, validation_dataset, test_dataset = imaging_tabular_dataset( df=[train_set, val_set, test_set], imaging_processor=imaging_processor, tabular_processor=tabular_processor, diff --git a/fuse_examples/classification/multimodality/mg_dataset_radiologist.py b/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py similarity index 77% rename from fuse_examples/classification/multimodality/mg_dataset_radiologist.py rename to fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py index 04eb1a8fa..27cb3ec3c 100644 --- a/fuse_examples/classification/multimodality/mg_dataset_radiologist.py +++ b/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py @@ -1,47 +1,56 @@ import pandas as pd import os -from typing import Callable, Optional -from typing import Tuple +from typing import Callable, Optional,Tuple,Dict, List +import torch -# from MedicalAnalyticsCore.DatabaseUtils.selected_studies_queries import get_annotations_and_findings -# from MedicalAnalyticsCore.DatabaseUtils.tableResolver import TableResolver -# from MedicalAnalyticsCore.DatabaseUtils.connection import create_homer_engine, Connection -# from MedicalAnalyticsCore.DatabaseUtils import tableNames -# from MedicalAnalyticsCore.DatabaseUtils import db_utils as db # from autogluon.tabular import TabularPredictor -from fuse_examples.classification.multimodality.dataset import IMAGING_TABULAR_dataset -from fuse.data.dataset.dataset_default import FuseDatasetDefault - -from fuse_examples.classification.MG_CMMD.input_processor import 
FuseMGInputProcessor +from fuse_examples.classification.multimodality.dataset import imaging_tabular_dataset +from fuse_examples.classification.cmmd.input_processor import FuseMGInputProcessor from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame - - -from typing import Dict, List -import torch +from fuse.data.dataset.dataset_default import FuseDatasetDefault +from fuse.data.processor.processor_base import FuseProcessorBase from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict class PostProcessing: - def __init__(self, continuous_tabular_features_lst: List, + """ + Post processing for the tabular data. + In this dataset, some of the features are continuous and the other are categorical + In additions there are two types of features - + 1. Objective features 2. Annotated features (by expert) based on the imaging + :param continuous_tabular_features_lst: columns in data that are continuous + :param categorical_tabular_features_lst: columns in data that are categorical + :param label_lst: label columns + :param annotated_features_lst: columns in data for annotated features (by expert) based on the imaging + :param non_annotated_features_lst: columns in data for non-annotated features (by expert) based on the imaging + :param use_annotated: + :param use_non_annotated + + """ + def __init__(self, + continuous_tabular_features_lst: List, categorical_tabular_features_lst: List, label_lst: List, - imaging_features_lst: List, - non_imaging_features_lst: List, - use_imaging: bool, - use_non_imaging: bool): + annotated_features_lst: List, + non_annotated_features_lst: List, + use_annotated: bool, + use_non_annotated: bool): + self.continuous_tabular_features_lst = continuous_tabular_features_lst self.categorical_tabular_features_lst = categorical_tabular_features_lst self.label_lst = label_lst - self.imaging_features_lst = imaging_features_lst - self.non_imaging_features_lst = non_imaging_features_lst - self.use_imaging = use_imaging - self.use_non_imaging = use_non_imaging + self.annotated_features_lst = annotated_features_lst + self.non_annotated_features_lst = non_annotated_features_lst + self.use_annotated = use_annotated + self.use_non_annotated = use_non_annotated def __call__(self, batch_dict: Dict) -> Dict: - if not self.use_imaging and not self.use_non_imaging: + + if not self.use_annotated and not self.use_non_annotated: raise ValueError('No features are in use') - mask_list = self.use_imaging * self.imaging_features_lst + self.use_non_imaging * self.non_imaging_features_lst + + mask_list = self.use_annotated * self.annotated_features_lst + self.use_non_annotated * self.non_annotated_features_lst mask_continuous = torch.zeros(len( self.continuous_tabular_features_lst)) for i in range(len(mask_list)): try: @@ -78,8 +87,6 @@ def __call__(self, batch_dict: Dict) -> Dict: return batch_dict -# feature selection univarient analysis - #-------------------Tabular def get_selected_features_mg(data,features_mode,key_columns): features_dict = tabular_feature_mg() @@ -158,40 +165,7 @@ def imaging_mg(imaging_filename,key_columns): if os.path.exists(imaging_filename): df = pd.read_csv(imaging_filename) else: - REVISION_DATE = '20200915' - TableResolver().set_revision(REVISION_DATE) - revision = {'prefix': 'sentara', 'suffix': REVISION_DATE} - engine = Connection().get_engine() - - df_with_findings = get_annotations_and_findings(engine, revision, - exam_types=['MG'], viewpoints=None, # ['CC','MLO'], \ - include_findings=True, remove_invalids=True, - 
remove_heldout=False, \ - remove_excluded=False, remove_less_than_4views=False, \ - load_from_file=False, save_to_file=False) - - # dicom_table = db.get_table_as_dataframe(engine, tableNames.get_dicom_tags_table_name(revision)) - # study_statuses = db.get_table_as_dataframe(engine, tableNames.get_study_statuses_table_name(revision)) - my_providers = ['sentara'] - df = df_with_findings.loc[df_with_findings['provider'].isin(my_providers)] - # fixing assymetry - asymmetries = ['asymmetry', 'developing asymmetry', 'focal asymmetry', 'global asymmetry'] - df['is_asymmetry'] = df['pathology'].isin(asymmetries) - df['is_Breast_Assymetry'] = df['type'].isin(['Breast Assymetry']) - df.loc[df['is_asymmetry'], 'pathology'] = df[df['is_asymmetry']]['biopsy_outcome'] - df.loc[df['is_Breast_Assymetry'], 'pathology'] = df[df['is_Breast_Assymetry']]['biopsy_outcome'] - # remove duble xmls - aa_unsorted = df - aa_unsorted.sort_values('xml_url', ascending=False, inplace=True) - xml_url_to_keep = aa_unsorted.groupby(['image_id'])['xml_url'].transform('first') - df = aa_unsorted[aa_unsorted['xml_url'] == xml_url_to_keep] - remove_from_pathology = ['undefined', 'not_applicable', 'Undefined', 'extracapsular rupture of breast implant', - 'intracapsular rupture of breast implant'] - is_pathology = ~df.pathology.isnull() & ~df.pathology.isin(remove_from_pathology) - is_digital = df.image_source == 'Digital' - is_biopsy = df.finding_biopsy.isin(['negative', 'negative high risk', 'positive']) - df = df[(is_digital) & (is_pathology) & (is_biopsy)] - df.to_csv(imaging_filename) + df = data_curation(imaging_filename) df1 = df.groupby(key_columns)[img_sample_column].apply(lambda x: list(map(str, x))).reset_index() df2 = df.groupby(key_columns)[label_column].apply(lambda x: list(map(str, x))).reset_index() @@ -199,7 +173,6 @@ def imaging_mg(imaging_filename,key_columns): return pd.merge(df1,df2,on=key_columns) - #------------------Imaging+Tabular def merge_datasets(tabular_filename,imaging_filename,key_columns): tabular_data = tabular_mg(tabular_filename, key_columns) @@ -228,14 +201,13 @@ def apply_gluon_baseline(train_set,test_set,label,save_path): y_pred = predictor.predict_proba(test_data) perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) -#MO: thinkabout specific name -def MG_dataset(tabular_filename:str, +def mg_clinical_annotations_dataset( + tabular_filename:str, imaging_filename:str, train_val_test_filenames:list, - #Mo: internal parameters - imaging_processor, - tabular_processor, + imaging_processor: FuseProcessorBase,, + tabular_processor: FuseProcessorBase,, key_columns:list, label_key:str, @@ -259,7 +231,7 @@ def MG_dataset(tabular_filename:str, features_list = list(tabular_columns) [features_list.remove(x) for x in key_columns] - train_dataset, validation_dataset, test_dataset = IMAGING_TABULAR_dataset( + train_dataset, validation_dataset, test_dataset = imaging_tabular_dataset( df=[train_set, val_set, test_set], imaging_processor=imaging_processor, tabular_processor=tabular_processor, @@ -290,7 +262,7 @@ def MG_dataset(tabular_filename:str, label_column = 'finding_biopsy' img_sample_column = 'dcm_url' train_dataset, validation_dataset, test_dataset = \ - MG_dataset(tabular_filename=tabular_filename, + mg_clinical_annotations_dataset(tabular_filename=tabular_filename, imaging_filename=imaging_filename, train_val_test_filenames=train_val_test_filenames, key_columns=key_columns, diff --git a/fuse_examples/classification/multimodality/runner.py 
b/fuse_examples/classification/multimodality/runner.py index ca5aea080..bd25728d6 100644 --- a/fuse_examples/classification/multimodality/runner.py +++ b/fuse_examples/classification/multimodality/runner.py @@ -15,42 +15,27 @@ """ import os - -from fuse.utils.utils_debug import FuseUtilsDebug -from fuse.utils.utils_gpu import FuseUtilsGPU - import logging - -import torch.optim as optim from torch.utils.data.dataloader import DataLoader +import fuse.utils.gpu as FuseUtilsGPU +from fuse.utils.utils_debug import FuseUtilsDebug from fuse.utils.utils_logger import fuse_logger_start - -from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch - -from fuse.metrics.classification.metric_auc import FuseMetricAUC -from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy - from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback from fuse.managers.callbacks.callback_metric_statistics import FuseMetricStatisticsCallback from fuse.managers.callbacks.callback_time_statistics import FuseTimeStatisticsCallback from fuse.managers.manager_default import FuseManagerDefault - -from fuse_examples.classification.multimodality.mg_dataset_radiologist import MG_dataset -from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths from fuse.models.backbones.backbone_inception_resnet_v2 import FuseBackboneInceptionResnetV2 from fuse.models.backbones.backbone_mlp import FuseMultilayerPerceptronBackbone - - -from fuse.analyzer.analyzer_default import FuseAnalyzerDefault -from fuse.metrics.classification.metric_roc_curve import FuseMetricROCCurve - - -from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor +from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch +from fuse.eval.evaluator import EvaluatorDefault from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame -from fuse_examples.classification.multimodality.mg_dataset_radiologist import PostProcessing,tabular_feature_mg -from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters +from fuse_examples.classification.cmmd.input_processor import FuseMGInputProcessor +from fuse_examples.classification.multimodality.mg_dataset_clinical_and_annotations import PostProcessing,tabular_feature_mg +from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters +from fuse_examples.classification.multimodality.mg_dataset_clinical_and_annotations import mg_clinical_annotations_dataset +from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths @@ -157,7 +142,7 @@ -TRAIN_COMMON_PARAMS['dataset_func'] = MG_dataset( +TRAIN_COMMON_PARAMS['dataset_func'] = mg_clinical_annotations_dataset( tabular_filename=TRAIN_COMMON_PARAMS['paths']['tabular_filename'], imaging_filename=TRAIN_COMMON_PARAMS['paths']['imaging_filename'], train_val_test_filenames=TRAIN_COMMON_PARAMS['paths']['train_val_test_filenames'], @@ -348,11 +333,11 @@ def run_analyze(paths: dict, analyze_common_params: dict): # create analyzer - analyzer = FuseAnalyzerDefault() + analyzer = EvaluatorDefault() # run # FIXME: simplify analyze interface for this case - results = analyzer.analyze(gt_processors={}, + results = analyzer.eval(gt_processors={}, data_pickle_filename=os.path.join(paths["inference_dir"], analyze_common_params["infer_filename"]), metrics=metrics, From f84b336960734b1eb709a8a65361e9bf115bc638 Mon Sep 17 00:00:00 2001 From: ttlusty Date: Tue, 3 
May 2022 10:35:04 +0300 Subject: [PATCH 3/7] refactoring post review 1 --- .../multimodality/data_curation.py | 58 ++++ .../classification/multimodality/dataset.py | 2 +- .../mg_dataset_clinical_and_annotations.py | 33 +-- .../multimodality/model_tabular_imaging.py | 277 ++++++++---------- .../multimodality/multimodel_parameters.py | 43 +-- .../classification/multimodality/runner.py | 19 +- 6 files changed, 213 insertions(+), 219 deletions(-) create mode 100644 fuse_examples/classification/multimodality/data_curation.py diff --git a/fuse_examples/classification/multimodality/data_curation.py b/fuse_examples/classification/multimodality/data_curation.py new file mode 100644 index 000000000..26f78efa3 --- /dev/null +++ b/fuse_examples/classification/multimodality/data_curation.py @@ -0,0 +1,58 @@ +from MedicalAnalyticsCore.DatabaseUtils.selected_studies_queries import get_annotations_and_findings +from MedicalAnalyticsCore.DatabaseUtils.tableResolver import TableResolver +from MedicalAnalyticsCore.DatabaseUtils.connection import create_homer_engine, Connection +from MedicalAnalyticsCore.DatabaseUtils import tableNames +from MedicalAnalyticsCore.DatabaseUtils import db_utils as db + + +def mg_data_curation(imaging_filename): + REVISION_DATE = '20200915' + TableResolver().set_revision(REVISION_DATE) + revision = {'prefix': 'sentara', 'suffix': REVISION_DATE} + engine = Connection().get_engine() + + df_with_findings = get_annotations_and_findings(engine, revision, + exam_types=['MG'], viewpoints=None, # ['CC','MLO'], \ + include_findings=True, remove_invalids=True, + remove_heldout=False, \ + remove_excluded=False, remove_less_than_4views=False, \ + load_from_file=False, save_to_file=False) + + my_providers = ['sentara'] + df = df_with_findings.loc[df_with_findings['provider'].isin(my_providers)] + # fixing assymetry + asymmetries = ['asymmetry', 'developing asymmetry', 'focal asymmetry', 'global asymmetry'] + df['is_asymmetry'] = df['pathology'].isin(asymmetries) + df['is_Breast_Assymetry'] = df['type'].isin(['Breast Assymetry']) + df.loc[df['is_asymmetry'], 'pathology'] = df[df['is_asymmetry']]['biopsy_outcome'] + df.loc[df['is_Breast_Assymetry'], 'pathology'] = df[df['is_Breast_Assymetry']]['biopsy_outcome'] + # remove duble xmls + aa_unsorted = df + aa_unsorted.sort_values('xml_url', ascending=False, inplace=True) + xml_url_to_keep = aa_unsorted.groupby(['image_id'])['xml_url'].transform('first') + df = aa_unsorted[aa_unsorted['xml_url'] == xml_url_to_keep] + remove_from_pathology = ['undefined', 'not_applicable', 'Undefined', 'extracapsular rupture of breast implant', + 'intracapsular rupture of breast implant'] + is_pathology = ~df.pathology.isnull() & ~df.pathology.isin(remove_from_pathology) + is_digital = df.image_source == 'Digital' + is_biopsy = df.finding_biopsy.isin(['negative', 'negative high risk', 'positive']) + df = df[(is_digital) & (is_pathology) & (is_biopsy)] + df.to_csv(imaging_filename) + +#------------------Baseline +def apply_gluon_baseline(train_set,test_set,label,save_path): + from autogluon.tabular import TabularPredictor + + predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set) + results = predictor.fit_summary(show_plot=True) + + # Inference time: + y_test = test_set[label] + test_data = test_set.drop(labels=[label], + axis=1) # delete labels from test data since we wouldn't have them in practice + print(test_data.head()) + + predictor = TabularPredictor.load( + save_path) + y_pred = predictor.predict_proba(test_data) + 
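+    # predict_proba returns one probability column per class; with the
+    # eval_metric='roc_auc' configured above, evaluate_predictions below scores
+    # the positive-class probabilities in y_pred against y_test.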
perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/dataset.py b/fuse_examples/classification/multimodality/dataset.py index 99ada0af9..eacb86f90 100644 --- a/fuse_examples/classification/multimodality/dataset.py +++ b/fuse_examples/classification/multimodality/dataset.py @@ -60,7 +60,7 @@ def imaging_augmentation()-> Iterable[Any]: # return tabular_processor -def imaging_tabular_dataset(data_split: List[pd.Dataframe], +def imaging_tabular_dataset(data_split: List[pd.DataFrame], imaging_processor: FuseProcessorBase, tabular_processor: FuseProcessorBase, label_key:str, diff --git a/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py b/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py index 27cb3ec3c..a2c0981fa 100644 --- a/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py +++ b/fuse_examples/classification/multimodality/mg_dataset_clinical_and_annotations.py @@ -3,9 +3,6 @@ from typing import Callable, Optional,Tuple,Dict, List import torch - - -# from autogluon.tabular import TabularPredictor from fuse_examples.classification.multimodality.dataset import imaging_tabular_dataset from fuse_examples.classification.cmmd.input_processor import FuseMGInputProcessor from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame @@ -123,7 +120,7 @@ def tabular_feature_mg(): 'race_1', 'race_2', 'race_3', 'race_4', 'race_5', 'race_6', 'race_7', 'race_8', 'race_9', 'race_10', 'max_prev_birad_class_0', 'max_prev_birad_class_1', 'max_prev_birad_class_2', 'max_prev_birad_class_3'] #63 categorical clinical features - features_dict['visual_feat'] = ['findings_size', 'findings_x_max', 'findings_y_max', 'side', 'is_distortions', + features_dict['annotated_feat'] = ['findings_size', 'findings_x_max', 'findings_y_max', 'side', 'is_distortions', 'is_spiculations', 'is_susp_calcifications', 'breast_density_1', 'breast_density_2', 'breast_density_3', 'breast_density_4', 'final_side_birad_0', 'final_side_birad_1', @@ -138,7 +135,7 @@ def tabular_feature_mg(): 'type_0', 'type_1', 'type_2', 'type_3', 'type_4', 'type_5', 'type_6', 'max_prev_birad_class_0', 'max_prev_birad_class_1', 'max_prev_birad_class_2', 'max_prev_birad_class_3'] - features_dict['non_visual_feat'] = ['DistanceSourceToPatient', 'DistanceSourceToDetector', 'x_pixel_spacing', + features_dict['non_annotated_feat'] = ['DistanceSourceToPatient', 'DistanceSourceToDetector', 'x_pixel_spacing', 'XRayTubeCurrent', 'CompressionForce', 'exposure_time', 'KVP', 'body_part_thickness', 'RelativeXRayExposure', @@ -146,7 +143,6 @@ def tabular_feature_mg(): 'race_6', 'race_7', 'race_8', 'race_9', 'race_10'] - return features_dict @@ -165,7 +161,8 @@ def imaging_mg(imaging_filename,key_columns): if os.path.exists(imaging_filename): df = pd.read_csv(imaging_filename) else: - df = data_curation(imaging_filename) + from fuse_examples.classification.multimodality.data_curation import mg_data_curation + df = mg_data_curation(imaging_filename) df1 = df.groupby(key_columns)[img_sample_column].apply(lambda x: list(map(str, x))).reset_index() df2 = df.groupby(key_columns)[label_column].apply(lambda x: list(map(str, x))).reset_index() @@ -184,30 +181,14 @@ def merge_datasets(tabular_filename,imaging_filename,key_columns): dataset = pd.merge(tabular_data, imaging_data, on=key_columns, how='inner') return dataset,tabular_columns,imaging_columns 
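+# A minimal, self-contained sketch (hypothetical toy data, not part of the example)
+# of the inner join merge_datasets performs above; the real code merges the tabular
+# and imaging tables on `key_columns` such as ['patient_id', 'study_id']:
+#
+#   import pandas as pd
+#   tabular = pd.DataFrame({'patient_id': [1, 2, 3], 'age': [54, 61, 47]})
+#   imaging = pd.DataFrame({'patient_id': [1, 2], 'dcm_url': ['a.dcm', 'b.dcm']})
+#   merged = pd.merge(tabular, imaging, on=['patient_id'], how='inner')
+#   # -> 2 rows: only patients 1 and 2 appear in both tables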
-#------------------Baseline -def apply_gluon_baseline(train_set,test_set,label,save_path): - - predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set) - results = predictor.fit_summary(show_plot=True) - - # Inference time: - y_test = test_set[label] - test_data = test_set.drop(labels=[label], - axis=1) # delete labels from test data since we wouldn't have them in practice - print(test_data.head()) - - predictor = TabularPredictor.load( - save_path) - y_pred = predictor.predict_proba(test_data) - perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) def mg_clinical_annotations_dataset( tabular_filename:str, imaging_filename:str, train_val_test_filenames:list, - imaging_processor: FuseProcessorBase,, - tabular_processor: FuseProcessorBase,, + imaging_processor: FuseProcessorBase, + tabular_processor: FuseProcessorBase, key_columns:list, label_key:str, @@ -232,7 +213,7 @@ def mg_clinical_annotations_dataset( features_list = list(tabular_columns) [features_list.remove(x) for x in key_columns] train_dataset, validation_dataset, test_dataset = imaging_tabular_dataset( - df=[train_set, val_set, test_set], + data_split=[train_set, val_set, test_set], imaging_processor=imaging_processor, tabular_processor=tabular_processor, label_key=label_key, diff --git a/fuse_examples/classification/multimodality/model_tabular_imaging.py b/fuse_examples/classification/multimodality/model_tabular_imaging.py index 08ee2e8a8..cdeae242a 100644 --- a/fuse_examples/classification/multimodality/model_tabular_imaging.py +++ b/fuse_examples/classification/multimodality/model_tabular_imaging.py @@ -6,163 +6,93 @@ from typing import Dict, Tuple, Sequence from fuse.models.backbones.backbone_inception_resnet_v2 import FuseBackboneInceptionResnetV2 +class project_imaging(nn.Module): - -# class Fusefchead(nn.Module): -# def __init__(self, -# cat_representations: Sequence[Tuple[str, int]] = (('model.cat_representations', 1),), -# backbone: FuseMultilayerPerceptronBackbone = FuseMultilayerPerceptronBackbone( -# layers=[2], mlp_input_size=512), -# ) -> None: -# super().__init__() -# -# self.cat_representations = cat_representations -# self.backbone = backbone -# def forward(self, batch_dict: Dict) -> Dict: -# cat_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.cat_representations[0][0]) -# logits = self.backbone(cat_representations) -# preds = F.softmax(logits, dim=1) -# -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.logits', logits) -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.output', preds) -# -# return batch_dict - -# class FuseModelImagingTabularHead(torch.nn.Module): -# def __init__(self, -# backbone: torch.nn.Module, -# heads: Sequence[torch.nn.Module], -# ) -> None: -# super().__init__() -# self.backbone = backbone -# self.heads = torch.nn.ModuleList(heads) -# self.add_module('heads', self.heads) -# -# def forward(self, batch_dict: Dict) -> Dict: -# representations_batch_dict = self.backbone.forward(batch_dict) -# imaging_representations = FuseUtilsHierarchicalDict.get(representations_batch_dict, 'imaging_representations') -# tabular_representations = FuseUtilsHierarchicalDict.get(representations_batch_dict, 'tabular_representations') -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_representations', imaging_representations) -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_representations', tabular_representations) -# if len(imaging_representations.shape)<2: -# imaging_representations = 
imaging_representations.unsqueeze(dim=0) -# if len(tabular_representations.shape)<2: -# tabular_representations = tabular_representations.unsqueeze(dim=0) -# cat_representations = torch.cat((tabular_representations, imaging_representations), 1) -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.cat_representations', cat_representations) -# for head in self.heads: -# batch_dict = head.forward(batch_dict) -# return batch_dict['model'] - -# class Fusesoftmax(nn.Module): -# def __init__(self, -# logits: Sequence[Tuple[str, int]] = (('model.features', 1),), -# ) -> None: -# super().__init__() -# -# self.logits = logits -# def forward(self, batch_dict: Dict) -> Dict: -# logits = FuseUtilsHierarchicalDict.get(batch_dict, self.logits[0][0]) -# preds = F.softmax(logits, dim=1) -# -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.logits', logits) -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.output', preds) -# -# return batch_dict - -# class FuseModelTabularImaging(torch.nn.Module): -# def __init__(self, -# continuous_tabular_input: Tuple[Tuple[str, int], ...], -# categorical_tabular_input: Tuple[Tuple[str, int], ...], -# imaging_inputs: Tuple[Tuple[str, int], ...], -# backbone_categorical_tabular: torch.nn.Module = None, -# backbone_continuous_tabular: torch.nn.Module = None, -# backbone_imaging: torch.nn.Module = None, -# projection_imaging: nn.Conv2d = nn.Conv2d(384, 256, kernel_size=1, stride=1) -# ) -> None: -# super().__init__() -# self.continuous_tabular_input = continuous_tabular_input -# self.categorical_tabular_input = categorical_tabular_input -# self.imaging_inputs = imaging_inputs -# self.backbone_categorical_tabular = backbone_categorical_tabular -# self.backbone_continuous_tabular = backbone_continuous_tabular -# self.backbone_imaging = backbone_imaging -# self.projection_imaging = projection_imaging -# -# def forward(self, batch_dict: Dict) -> Dict: -# -# #tabular encoder -# categorical_input = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0]) -# categorical_embeddings = self.backbone_categorical_tabular(categorical_input) -# continuous_input = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0]) -# input_cat = torch.cat((categorical_embeddings, continuous_input), 1) -# tabular_representations = self.backbone_continuous_tabular(input_cat) #dim 256 -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_representations', tabular_representations) -# -# #imaging encoder -# imaging_input = FuseUtilsHierarchicalDict.get(batch_dict, self.imaging_inputs[0][0]) -# backbone_imaging_features = self.backbone_imaging.forward(imaging_input) -# res = F.max_pool2d(backbone_imaging_features, kernel_size=backbone_imaging_features.shape[2:]) -# imaging_representations = self.projection_imaging.forward(res) -# imaging_representations = torch.squeeze(imaging_representations) -# FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_representations', imaging_representations) -# -# return batch_dict['model'] - - - -# concat model -class TabularImagingConcat(nn.Module): - def __init__(self, pooling='max',projection_imaging: nn.Conv2d = nn.Conv2d(384, 256, kernel_size=1, stride=1)): + def __init__(self, pooling='max', dim='2d', projection_imaging: nn.Module = None): super().__init__() assert pooling in ('max', 'avg') + assert dim in ('2d', '3d') self.pooling = pooling + self.dim = dim self.projection_imaging = projection_imaging - def fix_imaging(self,imaging_features): + def forward(self, imaging_features): if self.pooling == 'max': - 
imaging_features = F.max_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
+            if self.dim == '2d':
+                imaging_features = F.max_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
+            else:
+                imaging_features = F.max_pool3d(imaging_features, kernel_size=imaging_features.shape[2:])
         elif self.pooling == 'avg':
-            imaging_features = F.avg_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
+            if self.dim == '2d':
+                imaging_features = F.avg_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
+            else:
+                imaging_features = F.avg_pool3d(imaging_features, kernel_size=imaging_features.shape[2:])
+
+        if self.projection_imaging is not None:
+            imaging_features = self.projection_imaging.forward(imaging_features)
+            imaging_features = torch.squeeze(torch.squeeze(imaging_features, dim=3), dim=2)
 
-        imaging_features = self.projection_imaging.forward(imaging_features)
-        imaging_features = torch.squeeze(torch.squeeze(imaging_features,dim=3),dim=2)
         return imaging_features
 
-    def forward(self, batch_dict):
+class project_tabular(nn.Module):
 
-        imaging_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.imaging_features')
+    def __init__(self, projection_tabular: nn.Module = None):
+        super().__init__()
+        self.projection_tabular = projection_tabular
+
+    def forward(self, tabular_features):
+        if self.projection_tabular is not None:
+            tabular_features = self.projection_tabular.forward(tabular_features)
+            tabular_features = torch.squeeze(torch.squeeze(tabular_features, dim=3), dim=2)
+
+        return tabular_features
+
+# concat model
+class TabularImagingConcat(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, batch_dict):
         tabular_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.tabular_features')
-        imaging_features = self.fix_imaging(imaging_features)
+        imaging_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.imaging_features')
         res = torch.cat([tabular_features, imaging_features], dim=1)
         return res
 
-
 #Tabular model
 class FuseModelTabularContinuousCategorical(torch.nn.Module):
     def __init__(self,
                  continuous_tabular_input: Tuple[Tuple[str, int], ...],
                  categorical_tabular_input: Tuple[Tuple[str, int], ...],
-                 backbone_categorical_tabular: FuseMultilayerPerceptronBackbone,
-                 backbone_continuous_tabular: FuseMultilayerPerceptronBackbone,
+                 backbone_categorical_tabular: torch.nn.Module,
+                 backbone_continuous_tabular: torch.nn.Module,
+                 backbone_cat_tabular: torch.nn.Module,
                  heads: Sequence[torch.nn.Module],
                  ) -> None:
         super().__init__()
         self.continuous_tabular_input = continuous_tabular_input
         self.categorical_tabular_input = categorical_tabular_input
         self.backbone_categorical_tabular = backbone_categorical_tabular
-        self.backbone_cat_tabular = backbone_continuous_tabular
-        # self.add_module('backbone', self.backbone)
+        self.backbone_continuous_tabular = backbone_continuous_tabular
+        self.backbone_cat_tabular = backbone_cat_tabular
         self.heads = torch.nn.ModuleList(heads)
         self.add_module('heads', self.heads)
 
     def forward(self, batch_dict: Dict) -> Dict:
-        categorical_input = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0])
-        categorical_embeddings = self.backbone_categorical_tabular(categorical_input)
-        continuous_input = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0])
-        input_cat = torch.cat((categorical_embeddings, continuous_input), 1)
+        if self.backbone_categorical_tabular:
+            categorical_input = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0])
+            categorical_embeddings = self.backbone_categorical_tabular(categorical_input)
+        else:
+            categorical_embeddings = FuseUtilsHierarchicalDict.get(batch_dict, self.categorical_tabular_input[0][0])
+
+        if self.backbone_continuous_tabular:
+            continuous_input = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0])
+            continuous_embeddings = self.backbone_continuous_tabular(continuous_input)
+        else:
+            continuous_embeddings = FuseUtilsHierarchicalDict.get(batch_dict, self.continuous_tabular_input[0][0])
+
+        input_cat = torch.cat((categorical_embeddings, continuous_embeddings), 1)
         tabular_features = self.backbone_cat_tabular(input_cat)
 
         FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_features', tabular_features)
@@ -177,10 +107,10 @@ def __init__(self,
                  imaging_inputs: Tuple[Tuple[str, int], ...]=None,
                  tabular_backbone: torch.nn.Module=None,
                  imaging_backbone: torch.nn.Module=None,
+                 tabular_projection: torch.nn.Module=None,
+                 imaging_projection: torch.nn.Module = None,
                  multimodal_backbone: torch.nn.Module=None,
-                 tabular_heads: Sequence[torch.nn.Module]=None,
-                 imaging_heads: Sequence[torch.nn.Module]=None,
-                 multimodal_heads: Sequence[torch.nn.Module]=None,
+                 heads: Sequence[torch.nn.Module]=None,
                  ) -> None:
         super().__init__()
@@ -198,27 +128,18 @@ def __init__(self,
         if self.multimodal_backbone:
             self.add_module('multimodal_backbone', multimodal_backbone)
 
+        self.tabular_projection = tabular_projection
+        if self.tabular_projection:
+            self.add_module('tabular_projection', tabular_projection)
 
-        self.tabular_heads = torch.nn.ModuleList(tabular_heads)
-        if self.tabular_heads:
-            self.add_module('tabular_heads', self.tabular_heads)
-
-        self.imaging_heads = torch.nn.ModuleList(imaging_heads)
-        if self.imaging_heads:
-            self.add_module('imaging_heads', self.imaging_heads)
-
-        self.multimodal_heads = torch.nn.ModuleList(multimodal_heads)
-        if self.multimodal_heads:
-            self.add_module('multimodal_heads', self.multimodal_heads)
+        self.imaging_projection = imaging_projection
+        if self.imaging_projection:
+            self.add_module('imaging_projection', imaging_projection)
 
-    def tabular_modules(self):
-        return [self.tabular_backbone, self.tabular_heads]
-
-    def imaging_modules(self):
-        return [self.imaging_backbone, self.imaging_heads]
+        self.heads = torch.nn.ModuleList(heads)
+        if self.heads:
+            self.add_module('heads', self.heads)
 
-    def multimodal_modules(self):
-        return [self.multimodal_backbone, self.multimodal_heads]
 
     def forward(self, batch_dict: Dict) -> Dict:
@@ -230,22 +151,24 @@ def forward(self, batch_dict: Dict) -> Dict:
             imaging_features = self.imaging_backbone.forward(imaging_input)
             FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_features', imaging_features)
 
+        if self.tabular_projection:
+            tabular_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.tabular_features')
+            tabular_features = self.tabular_projection.forward(tabular_features)
+            FuseUtilsHierarchicalDict.set(batch_dict, 'model.tabular_features', tabular_features)
+
+        if self.imaging_projection:
+            imaging_features = FuseUtilsHierarchicalDict.get(batch_dict, 'model.imaging_features')
+            imaging_features = self.imaging_projection.forward(imaging_features)
+            FuseUtilsHierarchicalDict.set(batch_dict, 'model.imaging_features', imaging_features)
+
         if self.multimodal_backbone:
             multimodal_features = self.multimodal_backbone.forward(batch_dict)
             FuseUtilsHierarchicalDict.set(batch_dict, 'model.multimodal_features', multimodal_features)
 
         # run through heads
-        if 
self.tabular_heads: - for head in self.tabular_heads: - batch_dict = head.forward(batch_dict) - - if self.imaging_heads: - for head in self.imaging_heads: - batch_dict = head.forward(batch_dict) - - if self.multimodal_heads: - for head in self.multimodal_heads: + if self.heads: + for head in self.heads: batch_dict = head.forward(batch_dict) return batch_dict['model'] @@ -253,11 +176,50 @@ def forward(self, batch_dict: Dict) -> Dict: if __name__ == '__main__': import torch + from fuse.models.heads.head_1d_classifier import FuseHead1dClassifier + from fuse.models.heads.head_global_pooling_classifier import FuseHeadGlobalPoolingClassifier batch_dict = {'data.continuous': torch.randn(8, 14), 'data.categorical': torch.randn(8, 63), 'data.image': torch.randn(8, 1, 2200, 1200)} + + heads_for_multimodal = { + 'multimodal_head': + [ + FuseHead1dClassifier( + head_name='multimodal', + conv_inputs=(('model.multimodal_features', + 256 * 2),), + num_classes=2, + ) + ], + 'tabular_head': + [ + FuseHead1dClassifier( + head_name='tabular', + conv_inputs=(('model.tabular_features', + 256),), + num_classes=2, + ) + ], + 'imaging_head': + [ + FuseHeadGlobalPoolingClassifier( + head_name='imaging', + dropout_rate=0.5, + layers_description=(256,), + conv_inputs=(('model.imaging_features', + 384),), + num_classes=2, + pooling="avg", + ) + ], + + + } + + # model = FuseModelTabularImaging( # continuous_tabular_input=(('data.continuous', 1),), # categorical_tabular_input=(('data.categorical', 1),), @@ -281,6 +243,8 @@ def forward(self, batch_dict: Dict) -> Dict: backbone_continuous_tabular = FuseMultilayerPerceptronBackbone( layers=[256],mlp_input_size=77), heads=None, ) + model_projector_imaging = project_imaging(projection_imaging=nn.Conv2d(384, 256, kernel_size=1, stride=1)) + model_projector_tabular = None model_imaging = FuseBackboneInceptionResnetV2(input_channels_num=1) model_multimodel = TabularImagingConcat() @@ -289,7 +253,10 @@ def forward(self, batch_dict: Dict) -> Dict: imaging_inputs=(('data.image', 1),), tabular_backbone=model_tabular, imaging_backbone=model_imaging, + tabular_projection=model_projector_tabular, + imaging_projection=model_projector_imaging, multimodal_backbone = model_multimodel, + heads=[heads_for_multimodal['multimodal_head'][0]], ) res = model.forward(batch_dict) a=1 \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/multimodel_parameters.py b/fuse_examples/classification/multimodality/multimodel_parameters.py index ad05b2fab..5902784c9 100644 --- a/fuse_examples/classification/multimodality/multimodel_parameters.py +++ b/fuse_examples/classification/multimodality/multimodel_parameters.py @@ -2,10 +2,8 @@ from fuse_examples.classification.multimodality.model_tabular_imaging import * from fuse.losses.loss_default import FuseLossDefault import torch.nn.functional as F -from fuse.metrics.classification.metric_auc import FuseMetricAUC -from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy +from fuse.eval.metrics.classification.metrics_classification_common import MetricAUCROC from fuse_examples.classification.multimodality.loss_multimodal_contrastive_learning import FuseLossMultimodalContrastiveLearning -from fuse.models.model_default import FuseModelDefault from fuse.models.heads.head_global_pooling_classifier import FuseHeadGlobalPoolingClassifier from fuse.models.heads.head_1d_classifier import FuseHead1dClassifier from fuse.models.model_ensemble import FuseModelEnsemble @@ -19,12 +17,16 @@ def 
multimodal_parameters(train_common_params: dict,infer_common_params: dict,an continuous_tabular_input=(('data.continuous', 1),), categorical_tabular_input=(('data.categorical', 1),), backbone_categorical_tabular=train_common_params['tabular_encoder_categorical'], - backbone_continuous_tabular = train_common_params['tabular_encoder_continuous'], + backbone_continuous_tabular=train_common_params['tabular_encoder_continuous'], + backbone_cat_tabular=train_common_params['tabular_encoder_cat'], heads=None, ) model_imaging = train_common_params['imaging_encoder'] model_multimodel_concat = TabularImagingConcat() + model_projection_imaging = train_common_params['imaging_projector'] + model_projection_tabular = train_common_params['tabular_projector'] + heads_for_multimodal = { 'multimodal_head': [ @@ -71,10 +73,10 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an } metric_for_multimodal = { - 'multimodal_auc': FuseMetricAUC(pred_name='model.output.multimodal', target_name='data.gt'), - 'tabular_auc': FuseMetricAUC(pred_name='model.output.tabular', target_name='data.gt'), - 'imaging_auc': FuseMetricAUC(pred_name='model.output.imaging', target_name='data.gt'), - 'ensemble_auc':FuseMetricAUC(pred_name='model.output.tabular_ensemble_average', target_name='data.gt'), + 'multimodal_auc': MetricAUCROC(pred='model.output.multimodal', target='data.gt'), + 'tabular_auc': MetricAUCROC(pred='model.output.tabular', target='data.gt'), + 'imaging_auc': MetricAUCROC(pred='model.output.imaging', target='data.gt'), + 'ensemble_auc':MetricAUCROC(pred='model.output.tabular_ensemble_average', target='data.gt'), } ################################################ @@ -84,7 +86,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an train_common_params['model'] = FuseMultiModalityModel( tabular_inputs=(('data.continuous', 1), ('data.categorical', 1),), tabular_backbone=model_tabular, - tabular_heads=heads_for_multimodal['tabular_head'], + heads=heads_for_multimodal['tabular_head'], ) train_common_params['loss'] = { 'cls_loss': loss_for_multimodal['tabular_loss'], @@ -110,7 +112,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an train_common_params['model'] = FuseMultiModalityModel( imaging_inputs=(('data.image', 1),), imaging_backbone=model_imaging, - imaging_heads=heads_for_multimodal['imaging_head'], + heads=heads_for_multimodal['imaging_head'], ) train_common_params['loss'] = { 'cls_loss': loss_for_multimodal['imaging_loss'], @@ -133,9 +135,8 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an tabular_backbone=model_tabular, imaging_backbone=model_imaging, multimodal_backbone=model_multimodel_concat, - imaging_heads=heads_for_multimodal['imaging_head'], - tabular_heads=heads_for_multimodal['tabular_head'], - multimodal_heads=heads_for_multimodal['multimodal_head'], + imaging_projection= model_projection_imaging, + heads=[heads_for_multimodal['multimodal_head'][0]], ) train_common_params['loss'] = { @@ -176,22 +177,6 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'auc': metric_for_multimodal['ensemble_auc'], } - - # train_common_params['loss'] = { - # 'cls_loss': loss_for_multimodal['ensemble_loss'], - # - # } - # train_common_params['metrics'] = { - # 'auc': metric_for_multimodal['ensemble_auc'], - # } - # train_common_params['manager.learning_rate'] = 1e-5 - # train_common_params['manager.weight_decay'] = 0.001 - # - # train_common_params['optimizer'] = 
optim.Adam(train_common_params['model'].parameters(), lr=train_common_params['manager.learning_rate'], - # weight_decay=train_common_params['manager.weight_decay']) - # train_common_params['scheduler'] = optim.lr_scheduler.ReduceLROnPlateau(train_common_params['optimizer']) - - #Mo:different parameter if train_common_params['fusion_type'] == 'cotrastive': train_common_params['model'] = FuseModelTabularImaging( diff --git a/fuse_examples/classification/multimodality/runner.py b/fuse_examples/classification/multimodality/runner.py index bd25728d6..49ef58e2a 100644 --- a/fuse_examples/classification/multimodality/runner.py +++ b/fuse_examples/classification/multimodality/runner.py @@ -17,6 +17,7 @@ import os import logging from torch.utils.data.dataloader import DataLoader +import torch.nn as nn import fuse.utils.gpu as FuseUtilsGPU from fuse.utils.utils_debug import FuseUtilsDebug @@ -36,6 +37,7 @@ from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters from fuse_examples.classification.multimodality.mg_dataset_clinical_and_annotations import mg_clinical_annotations_dataset from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths +from fuse_examples.classification.multimodality.model_tabular_imaging import project_imaging, project_tabular @@ -65,7 +67,7 @@ root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/' # TODO: add path to the data folder assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location" # Name of the experiment -experiment = 'late_fusion_non_visual' +experiment = 'late_fusion_non_annotated' # Path to cache data cache_path = root_data+'/mg_radiologic/' @@ -104,7 +106,7 @@ # best_epoch_source # if an epoch values are the best so far, the epoch is saved as a checkpoint. 
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
-    'source':'metrics.auc.macro_avg',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries
+    'source':'metrics.auc',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries
     'optimization': 'max',  # can be either min/max
     'on_equal_values': 'better',  # can be either better/worse - whether to consider best epoch when values are equal
@@ -119,9 +121,9 @@
 TRAIN_COMMON_PARAMS['post_processing'] = PostProcessing(features_dic['continuous_clinical_feat'],
                                                         features_dic['categorical_clinical_feat'],
                                                         ['gt'],
-                                                        features_dic['visual_feat'],
-                                                        features_dic['non_visual_feat'],
-                                                        use_imaging=False, use_non_imaging=True)
+                                                        features_dic['annotated_feat'],
+                                                        features_dic['non_annotated_feat'],
+                                                        use_annotated=False, use_non_annotated=True)
 
 #define processors
 TRAIN_COMMON_PARAMS['imaging_processor'] = FuseMGInputProcessor
@@ -133,14 +135,15 @@
 TRAIN_COMMON_PARAMS['tabular_encoder_categorical'] = FuseMultilayerPerceptronBackbone(
     layers=[128, len(features_dic['categorical_clinical_feat'])],
     mlp_input_size=len(features_dic['categorical_clinical_feat']))
-TRAIN_COMMON_PARAMS['tabular_encoder_continuous'] = FuseMultilayerPerceptronBackbone(
+TRAIN_COMMON_PARAMS['tabular_encoder_continuous'] = None
+TRAIN_COMMON_PARAMS['tabular_encoder_cat'] = FuseMultilayerPerceptronBackbone(
     layers=[TRAIN_COMMON_PARAMS['tabular_feature_size']],
     mlp_input_size=len(features_dic['categorical_clinical_feat'])+\
                    len(features_dic['continuous_clinical_feat']))
 
 TRAIN_COMMON_PARAMS['imaging_encoder'] = FuseBackboneInceptionResnetV2(input_channels_num=1)
-
-
+TRAIN_COMMON_PARAMS['imaging_projector'] = project_imaging(projection_imaging=nn.Conv2d(TRAIN_COMMON_PARAMS['imaging_feature_size'], TRAIN_COMMON_PARAMS['tabular_feature_size'], kernel_size=1, stride=1))
+TRAIN_COMMON_PARAMS['tabular_projector'] = None
 
 TRAIN_COMMON_PARAMS['dataset_func'] = mg_clinical_annotations_dataset(
     tabular_filename=TRAIN_COMMON_PARAMS['paths']['tabular_filename'],

From 9a515ff812264169551b48b987e3d16afb54bf45 Mon Sep 17 00:00:00 2001
From: ttlusty
Date: Tue, 3 May 2022 12:01:54 +0300
Subject: [PATCH 4/7] fix contrastive example

---
 .../loss_multimodal_contrastive_learning.py   | 28 +++---
 .../multimodality/model_tabular_imaging.py    | 88 -------------------
 .../multimodality/multimodel_parameters.py    | 40 ++++++---
 .../classification/multimodality/runner.py    |  8 +-
 4 files changed, 43 insertions(+), 121 deletions(-)

diff --git a/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py b/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
index a2967dd77..e8ea8a440 100644
--- a/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
+++ b/fuse_examples/classification/multimodality/loss_multimodal_contrastive_learning.py
@@ -4,10 +4,16 @@
 from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
 
 
-def softcrossentropyloss(target, logits):
+def softmax_cross_entropy_with_logits(target, logits):
     """
     From the pytorch discussion Forum:
     https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501
+
+    Computes cross entropy between soft target distributions and logits
+    :param target: tensor of soft labels (one distribution per sample)
+    :param logits: tensor of raw (pre-softmax) model outputs
+    :return: scalar loss, averaged over the batch
+
     """
     logprobs = torch.nn.functional.log_softmax(logits, dim=1)
     loss = -(target * logprobs).sum() / logits.shape[0]
@@ -15,6 +21,10 @@ def softcrossentropyloss(target, logits):
 
 
 class 
FuseLossMultimodalContrastiveLearning: + """ + Based on Multimodality contrastive loss as described in: + https://openreview.net/pdf?id=T4gXBOXoIUr + """ def __init__(self, imaging_representations: str = None, tabular_representations: str = None, @@ -43,21 +53,9 @@ def __call__(self, batch_dict: Dict) -> torch.Tensor: mask = torch.eq(torch.transpose(label_vec, 0, 1), label_vec).float() logits_imaging_tabular = torch.matmul(imaging_representations, torch.transpose(tabular_representations, 0, 1))/self.temperature logits_tabular_imaging = torch.matmul(tabular_representations, torch.transpose(imaging_representations, 0, 1))/self.temperature - loss_imaging_tabular = softcrossentropyloss(mask, logits_imaging_tabular)/torch.sum(mask, 0) - loss_tabular_imaging = softcrossentropyloss(mask, logits_tabular_imaging)/torch.sum(mask, 0) + loss_imaging_tabular = softmax_cross_entropy_with_logits(mask, logits_imaging_tabular)/torch.sum(mask, 0) + loss_tabular_imaging = softmax_cross_entropy_with_logits(mask, logits_tabular_imaging)/torch.sum(mask, 0) return self.alpha*loss_tabular_imaging.sum() + (1-self.alpha)*loss_imaging_tabular.sum() -if __name__ == '__main__': - import torch - - batch_dict = {'model.imaging_representations': torch.randn(3, 2), - 'model.tabular_representations': torch.randn(3, 2), - 'data.label': torch.empty(3, dtype=torch.long).random_(2)} - loss = FuseLossMultimodalContrastiveLearning(temperature=0.1, - imaging_representations='model.imaging_representations', - tabular_representations='model.tabular_representations', - label='data.label') - res = loss(batch_dict) - print('Loss output = ' + str(res)) diff --git a/fuse_examples/classification/multimodality/model_tabular_imaging.py b/fuse_examples/classification/multimodality/model_tabular_imaging.py index cdeae242a..dbd361575 100644 --- a/fuse_examples/classification/multimodality/model_tabular_imaging.py +++ b/fuse_examples/classification/multimodality/model_tabular_imaging.py @@ -172,91 +172,3 @@ def forward(self, batch_dict: Dict) -> Dict: batch_dict = head.forward(batch_dict) return batch_dict['model'] - - -if __name__ == '__main__': - import torch - from fuse.models.heads.head_1d_classifier import FuseHead1dClassifier - from fuse.models.heads.head_global_pooling_classifier import FuseHeadGlobalPoolingClassifier - - batch_dict = {'data.continuous': torch.randn(8, 14), - 'data.categorical': torch.randn(8, 63), - 'data.image': torch.randn(8, 1, 2200, 1200)} - - - heads_for_multimodal = { - 'multimodal_head': - [ - FuseHead1dClassifier( - head_name='multimodal', - conv_inputs=(('model.multimodal_features', - 256 * 2),), - num_classes=2, - ) - ], - 'tabular_head': - [ - FuseHead1dClassifier( - head_name='tabular', - conv_inputs=(('model.tabular_features', - 256),), - num_classes=2, - ) - ], - 'imaging_head': - [ - FuseHeadGlobalPoolingClassifier( - head_name='imaging', - dropout_rate=0.5, - layers_description=(256,), - conv_inputs=(('model.imaging_features', - 384),), - num_classes=2, - pooling="avg", - ) - ], - - - } - - - # model = FuseModelTabularImaging( - # continuous_tabular_input=(('data.continuous', 1),), - # categorical_tabular_input=(('data.categorical', 1),), - # imaging_inputs=(('data.patch.input.input_0', 1),),) - # - # res = model(batch_dict) - - # model = FuseModelTabularContinuousCategorical( - # continuous_tabular_input=(('data.continuous', 1),), - # categorical_tabular_input=(('data.categorical', 1),), - # backbone_categorical_tabular=FuseMultilayerPerceptronBackbone(layers=[128, 63],mlp_input_size=63), - # 
backbone_continuous_tabular = FuseMultilayerPerceptronBackbone( layers=[256],mlp_input_size=77), - # heads=None, - # ) - # res = model.forward(batch_dict) - - model_tabular = FuseModelTabularContinuousCategorical( - continuous_tabular_input=(('data.continuous', 1),), - categorical_tabular_input=(('data.categorical', 1),), - backbone_categorical_tabular=FuseMultilayerPerceptronBackbone(layers=[128, 63],mlp_input_size=63), - backbone_continuous_tabular = FuseMultilayerPerceptronBackbone( layers=[256],mlp_input_size=77), - heads=None, - ) - model_projector_imaging = project_imaging(projection_imaging=nn.Conv2d(384, 256, kernel_size=1, stride=1)) - model_projector_tabular = None - model_imaging = FuseBackboneInceptionResnetV2(input_channels_num=1) - model_multimodel = TabularImagingConcat() - - model = FuseMultiModalityModel( - tabular_inputs=(('data.continuous', 1),('data.categorical', 1),), - imaging_inputs=(('data.image', 1),), - tabular_backbone=model_tabular, - imaging_backbone=model_imaging, - tabular_projection=model_projector_tabular, - imaging_projection=model_projector_imaging, - multimodal_backbone = model_multimodel, - heads=[heads_for_multimodal['multimodal_head'][0]], - ) - res = model.forward(batch_dict) - a=1 \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/multimodel_parameters.py b/fuse_examples/classification/multimodality/multimodel_parameters.py index 5902784c9..2a9d73c27 100644 --- a/fuse_examples/classification/multimodality/multimodel_parameters.py +++ b/fuse_examples/classification/multimodality/multimodel_parameters.py @@ -136,6 +136,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an imaging_backbone=model_imaging, multimodal_backbone=model_multimodel_concat, imaging_projection= model_projection_imaging, + tabular_projection=model_projection_tabular, heads=[heads_for_multimodal['multimodal_head'][0]], ) @@ -177,25 +178,36 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'auc': metric_for_multimodal['ensemble_auc'], } - #Mo:different parameter - if train_common_params['fusion_type'] == 'cotrastive': - train_common_params['model'] = FuseModelTabularImaging( - continuous_tabular_input=(('data.continuous', 1),), - categorical_tabular_input=(('data.categorical', 1),), - imaging_inputs=(('data.image', 1),), - backbone_categorical_tabular=train_common_params['tabular_encoder_categorical'], - backbone_continuous_tabular=train_common_params['tabular_encoder_continuous'], - backbone_imaging=train_common_params['imaging_encoder'], - - ) + if train_common_params['fusion_type'] == 'contrastive': + train_common_params['model'] = FuseMultiModalityModel( + tabular_inputs=(('data.continuous', 1), ('data.categorical', 1),), + imaging_inputs=(('data.image', 1),), + tabular_backbone=model_tabular, + imaging_backbone=model_imaging, + multimodal_backbone=None, + imaging_projection= model_projection_imaging, + tabular_projection=model_projection_tabular, + ) - train_common_params['loss'] = FuseLossMultimodalContrastiveLearning( - imaging_representations='model.imaging_representations', - tabular_representations='model.tabular_representations', + train_common_params['loss'] = {'cls_loss': FuseLossMultimodalContrastiveLearning( + imaging_representations='model.imaging_features', + tabular_representations='model.tabular_features', label='data.gt', temperature=0.1, alpha=0.5) + } train_common_params['metrics'] = None + train_common_params['manager.learning_rate'] = 1e-4 + 
train_common_params['manager.weight_decay'] = 1e-4 + train_common_params['manager.momentum'] = 0.9 + train_common_params['manager.step_size'] = 150 + train_common_params['manager.gamma'] = 0.1 + train_common_params['optimizer'] = optim.SGD(train_common_params['model'].parameters(), + lr=train_common_params['manager.learning_rate'], + momentum=train_common_params['manager.momentum'], + weight_decay=train_common_params['manager.weight_decay']) + train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], + gamma=train_common_params['manager.gamma']) return train_common_params,infer_common_params,analyze_common_params \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/runner.py b/fuse_examples/classification/multimodality/runner.py index 49ef58e2a..7bbcf6a07 100644 --- a/fuse_examples/classification/multimodality/runner.py +++ b/fuse_examples/classification/multimodality/runner.py @@ -67,13 +67,13 @@ root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/' # TODO: add path to the data folder assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location" # Name of the experiment -experiment = 'late_fusion_non_annotated' +experiment = 'contrastive_non_annotated' # Path to cache data cache_path = root_data+'/mg_radiologic/' paths = multimodal_paths(dataset_name, root_data, root, experiment, cache_path) TRAIN_COMMON_PARAMS['paths'] = paths -TRAIN_COMMON_PARAMS['fusion_type'] = 'late_fusion' +TRAIN_COMMON_PARAMS['fusion_type'] = 'contrastive' ###################################### # Inference Common Params ###################################### @@ -106,8 +106,8 @@ # best_epoch_source # if an epoch values are the best so far, the epoch is saved as a checkpoint. 
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = { - 'source':'metrics.auc',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries - 'optimization': 'max', # can be either min/max + 'source':'losses.cls_loss',#'metrics.auc',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries + 'optimization': 'min', # can be either min/max 'on_equal_values': 'better', # can be either better/worse - whether to consider best epoch when values are equal } From 0a452fc3ee4bed5295501660515644e068701b56 Mon Sep 17 00:00:00 2001 From: ttlusty Date: Wed, 1 Jun 2022 13:58:23 +0300 Subject: [PATCH 5/7] knight multimodal example --- .../multimodality/data_curation.py | 257 +++++++++++- .../multimodality/mg_dataset.py | 2 +- .../multimodality/model_tabular_imaging.py | 206 +++++++++- .../multimodality/multimodal_paths.py | 25 ++ .../multimodality/multimodel_parameters.py | 96 ++++- .../classification/multimodality/runner.py | 23 +- .../multimodality/runner_knight.py | 389 ++++++++++++++++++ 7 files changed, 955 insertions(+), 43 deletions(-) create mode 100644 fuse_examples/classification/multimodality/runner_knight.py diff --git a/fuse_examples/classification/multimodality/data_curation.py b/fuse_examples/classification/multimodality/data_curation.py index 26f78efa3..3046c8c63 100644 --- a/fuse_examples/classification/multimodality/data_curation.py +++ b/fuse_examples/classification/multimodality/data_curation.py @@ -4,10 +4,33 @@ from MedicalAnalyticsCore.DatabaseUtils import tableNames from MedicalAnalyticsCore.DatabaseUtils import db_utils as db +import pandas as pd +import numpy as np +from sklearn.model_selection import StratifiedKFold +import pickle + +#------------------Baseline +def apply_gluon_baseline(train_set,test_set,label,save_path): + from autogluon.tabular import TabularPredictor + + predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set) + results = predictor.fit_summary(show_plot=True) + + # Inference time: + y_test = test_set[label] + test_data = test_set.drop(labels=[label], + axis=1) # delete labels from test data since we wouldn't have them in practice + print(test_data.head()) + + predictor = TabularPredictor.load( + save_path) + y_pred = predictor.predict_proba(test_data) + perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True) def mg_data_curation(imaging_filename): REVISION_DATE = '20200915' TableResolver().set_revision(REVISION_DATE) + # extand dataset by adding 'baptist', 'froedtert', 'miami', 'ucsd' revision = {'prefix': 'sentara', 'suffix': REVISION_DATE} engine = Connection().get_engine() @@ -18,7 +41,7 @@ def mg_data_curation(imaging_filename): remove_excluded=False, remove_less_than_4views=False, \ load_from_file=False, save_to_file=False) - my_providers = ['sentara'] + my_providers = ['sentara']#extand dataset by adding 'baptist', 'froedtert', 'miami', 'ucsd' df = df_with_findings.loc[df_with_findings['provider'].isin(my_providers)] # fixing assymetry asymmetries = ['asymmetry', 'developing asymmetry', 'focal asymmetry', 'global asymmetry'] @@ -39,20 +62,224 @@ def mg_data_curation(imaging_filename): df = df[(is_digital) & (is_pathology) & (is_biopsy)] df.to_csv(imaging_filename) -#------------------Baseline -def apply_gluon_baseline(train_set,test_set,label,save_path): - from autogluon.tabular import TabularPredictor +def encode_one_hot(original_dataframe, features_to_encode): + for features in features_to_encode: + 
original_dataframe[features] = original_dataframe[features].astype('category', copy=False) + dummies = pd.get_dummies(data=original_dataframe[features_to_encode]) + res = pd.concat([original_dataframe, dummies], axis=1) + res = res.drop(features_to_encode, axis=1) + return res + + +if __name__ == '__main__': + clinical_data_path = '/gpfs/haifa/projects/m/msieve_dev3/usr/Tal/my_research/virtual_biopsy/experiments/biopsy-mal_benign_plus_clinical/' + table_data_files = ['demographics_and_breast_density.csv', + 'dicom_tags_extracted.csv', + 'max_prev_birad_class.csv' + ] + is_table_data_files = ['distortions_20210317.csv', + 'spiculations_20210317.csv', + 'susp_calcifications_20210317.csv', + ] + path = clinical_data_path + + REVISION_DATE = '20210317' + TableResolver().set_revision(REVISION_DATE) + revision = {'prefix': '', 'suffix': REVISION_DATE} + engine = Connection().get_engine() + + df_with_findings = get_annotations_and_findings(engine, revision, + exam_types=None, viewpoints=None, # ['CC','MLO'], + include_findings=True, remove_invalids=True, + remove_heldout=False, \ + remove_excluded=False, remove_less_than_4views=False, \ + load_from_file=False, save_to_file=False) + + # negative 5011 positive 3464,negative high risk 907 - finding biopsy + # negative 4906 positive 3561,negative high risk 915 - biopsy + cured_df = mg_data_curation(df_with_findings) + + # add dicom tags + dicom_tags_extracted_table = pd.read_csv(clinical_data_path + table_data_files[1]) + cured_df_w_dicoms = pd.merge(cured_df, dicom_tags_extracted_table, how='inner', + left_on=['xml_url'], right_on=['xml_url'], suffixes=('', '_')) + + # add demographics tags + demographics_and_breast_density_table = pd.read_csv(clinical_data_path + table_data_files[0]) + demographics_and_breast_density_table_cured = \ + demographics_and_breast_density_table[ + demographics_and_breast_density_table['study_id'].isin(cured_df_w_dicoms['study_id'])] + cured_df_w_dicoms_w_demographic = pd.merge(cured_df_w_dicoms, demographics_and_breast_density_table_cured, + how='inner', + left_on=['provider', 'patient_id', 'study_id', 'breast_density'], + right_on=['provider', 'patient_id', 'study_id', 'breast_density']) + + # add max_prev_birad_class + max_prev_birad_class_table = pd.read_csv(clinical_data_path + table_data_files[2]) + max_prev_birad_class_table_cured = \ + max_prev_birad_class_table[ + max_prev_birad_class_table['study_id'].isin(cured_df_w_dicoms['study_id'])] + cured_df_w_dicoms_w_demographic_w_birad = pd.merge(cured_df_w_dicoms_w_demographic, + max_prev_birad_class_table_cured, how='inner', + left_on=['provider', 'patient_id', 'study_id'], + right_on=['provider', 'patient_id', 'study_id']) + + # add features from report + is_distortions_table = pd.read_csv(clinical_data_path + is_table_data_files[0]) + cured_df_w_dicoms_w_demographic_w_birad['is_distortions'] = 0 + cured_df_w_dicoms_w_demographic_w_birad['is_distortions'][ + (cured_df_w_dicoms_w_demographic_w_birad['provider'].isin(is_distortions_table['provider'])) & + (cured_df_w_dicoms_w_demographic_w_birad['patient_id'].isin(is_distortions_table['patient_id'])) & + (cured_df_w_dicoms_w_demographic_w_birad['study_id'].isin(is_distortions_table['study_id']))] = 1 + + is_spiculations_table = pd.read_csv(clinical_data_path + is_table_data_files[1]) + cured_df_w_dicoms_w_demographic_w_birad['is_spiculations'] = 0 + cured_df_w_dicoms_w_demographic_w_birad['is_spiculations'][ + (cured_df_w_dicoms_w_demographic_w_birad['provider'].isin(is_spiculations_table['provider'])) & + 
(cured_df_w_dicoms_w_demographic_w_birad['patient_id'].isin(is_spiculations_table['patient_id'])) & + (cured_df_w_dicoms_w_demographic_w_birad['study_id'].isin(is_spiculations_table['study_id']))] = 1 + + is_susp_calcifications_table = pd.read_csv(clinical_data_path + is_table_data_files[2]) + cured_df_w_dicoms_w_demographic_w_birad['is_susp_calcifications'] = 0 + cured_df_w_dicoms_w_demographic_w_birad['is_susp_calcifications'][ + (cured_df_w_dicoms_w_demographic_w_birad['provider'].isin(is_susp_calcifications_table['provider'])) & + (cured_df_w_dicoms_w_demographic_w_birad['patient_id'].isin(is_susp_calcifications_table['patient_id'])) & + (cured_df_w_dicoms_w_demographic_w_birad['study_id'].isin(is_susp_calcifications_table['study_id']))] = 1 + + cured_df_w_dicoms_w_demographic_w_birad.to_csv(path + 'curated_set_full_table_v2.csv') + cured_df_w_dicoms_w_demographic_w_birad = cured_df_w_dicoms_w_demographic_w_birad.drop( + cured_df_w_dicoms_w_demographic_w_birad[cured_df_w_dicoms_w_demographic_w_birad['contour'] == '{}'].index) + scanned_images = [ + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/382590/20100802_MM10071593/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.185594307.2225150464.3063965447_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/382590/20100802_MM10071593/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.185594307.2359368192.3063965447_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40052219/20130327_MM13037621/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.269456599.1298734592.3360623261_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40052219/20130327_MM13037621/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.269456599.1365843456.3360623261_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40113784/20090420_MM09036951/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.152018509.307699200.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40113784/20090420_MM09036951/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.152018509.374808064.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40124614/20121210_MM12151081/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.269446726.400694784.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40180058/20130219_MM13021798/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.269428388.2811436544.3360623261_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/40180058/20130219_MM13021798/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.269428900.2903318016.3360623261_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/50306630/20140206_MG140206004350/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.303013245.132455936.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/50306630/20140206_MG140206004350/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.303013245.4226096640.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/59991270/20100514_MM10044896/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.185536601.641998336.3360627099_8bit.xml', + 
'/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/59991270/20100514_MM10044896/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.185536601.776216064.3360627099_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/9985667/20090701_MM09062371/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.151978415.1949178112.179941807_8bit.xml', + '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Sentara/9985667/20090701_MM09062371/MG/8_bit/metadata/annotations/1.2.392.200036.9125.4.0.151978415.2150570240.179941807_8bit.xml'] + cured_df_w_dicoms_w_demographic_w_birad = cured_df_w_dicoms_w_demographic_w_birad[ + ~cured_df_w_dicoms_w_demographic_w_birad['xml_url'].isin(scanned_images)] + cured_df_w_dicoms_w_demographic_w_birad.to_csv(path + 'curated_set_full_table_v2_filtered.csv') + # Load csv file after manually fixing outlier row + df = pd.read_csv(path + 'curated_set_full_table_v2_filtered_fixed.csv') + # Fixing missing values + df['body_part_thickness'][ + df.xml_url == '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Baptist/FFF452F4CAC130D7556371C70B3459A4/20110206_73/MG/8_bit/metadata/annotations/FO-8874352202443175339_8bit_mn.xml'] = 183 + df['body_part_thickness'][ + df.xml_url == '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Baptist/6D3511BCDC25A14CD5F37D8BAE2B3F47/20111218_117/MG/8_bit/metadata/annotations/FO-3161254379748920352_8bit_mn.xml'] = 181 + df['body_part_thickness'][ + df.xml_url == '/gpfs/haifa/projects/m/msieve/MedicalSieve/PatientData/Baptist/225742C5EAC7484DCF4E59BDFA7D3F97/20140812_840/MG/8_bit/metadata/annotations/FO-6768169063846331327_8bit_mn.xml'] = 180 + df['breast_density'][df.breast_density == 'undefined'] = df['breast_density'].mode()[0] + df['birad'][(df.birad == 'not_applicable') | (df.birad == 'undefined')] = df['final_side_birad'][ + (df.birad == 'not_applicable') | (df.birad == 'undefined')] + df['age'][df.age == 0] = df['age'].mode()[0] + max_prev_birad_class_lst = [0, 1, 2, 3] + df['max_prev_birad_class'][~df.max_prev_birad_class.isin(max_prev_birad_class_lst)] = df['max_prev_birad_class'].mode()[0] + race_lst = ['african american', 'amer indian/alaskan', 'American Indian', 'Asian', 'Black', 'caucasian', 'hispanic', 'other', 'Pacific Islander', 'unknown', 'White'] + df['race'][~df.race.isin(race_lst)] = 'unknown' + longitudinal_change_lst = ['longitudinal_change', 'increase', 'new appearance', 'not_applicable', 'stable'] + df['longitudinal_change'][~df.longitudinal_change.isin(longitudinal_change_lst)] = 'unknown' + + df[['findings_x_max', 'findings_y_max']].multiply(df['x_pixel_spacing'], axis="index") + df['findings_size'].multiply(df['x_pixel_spacing'].pow(2), axis="index") + + id_lst = ['patient_id', 'xml_url'] + clinical_features_lst = ['breast_density', 'final_side_birad', 'side', 'birad', 'calcification', 'findings_size', + 'findings_x_max', + 'findings_y_max', 'longitudinal_change', 'type', 'DistanceSourceToPatient', + 'DistanceSourceToDetector', 'x_pixel_spacing', + 'XRayTubeCurrent', 'CompressionForce', 'exposure_time', 'KVP', 'body_part_thickness', + 'RelativeXRayExposure', 'exposure_in_mas', + 'age', 'race', 'max_prev_birad_class', 'is_distortions', 'is_spiculations', + 'is_susp_calcifications', 'biopsy'] + dataset = df[id_lst + clinical_features_lst] + dataset.to_csv(path + 'dataset.csv') + + # Convert biopsy into 0/1 label + dataset['biopsy'][dataset.biopsy == 'positive'] = 1 + dataset['biopsy'][(dataset.biopsy == 'negative') | (dataset.biopsy == 'negative high risk')] = 0 + + 
dataset = dataset.drop_duplicates(subset=['patient_id']) + dataset.to_csv(path + 'dataset_unique.csv') + + #Convert categorical data into numerical + dataset['final_side_birad'] = dataset['final_side_birad'].astype('category') + dataset['side'] = dataset['side'].astype('category') + dataset['birad'] = dataset['birad'].astype('category') + dataset['calcification'] = dataset['calcification'].astype('category') + dataset['longitudinal_change'] = dataset['longitudinal_change'].astype('category') + dataset['type'] = dataset['type'].astype('category') + dataset['race'] = dataset['race'].astype('category') + cat_columns = dataset.select_dtypes(['category']).columns + dataset[cat_columns] = dataset[cat_columns].apply(lambda x: x.cat.codes) + dataset.to_csv(path + 'dataset_numerical_unique.csv') + dataset['max_prev_birad_class'] = dataset['max_prev_birad_class'].astype(int) + features_to_encode = ['breast_density', 'final_side_birad', 'birad', + 'calcification', 'longitudinal_change', 'type', 'race', 'max_prev_birad_class'] + dataset = encode_one_hot(dataset, features_to_encode) + dataset = dataset.rename(columns={'patient_id': "sample_desc"}) + + FOLDS_NUMBER = 6 + X = dataset['sample_desc'].values + y = np.zeros(X.shape) + y[dataset['biopsy'].values > 0] = 1 + kfold = StratifiedKFold(n_splits=FOLDS_NUMBER, shuffle=True, random_state=1) + # enumerate the splits and summarize the distributions + db = {} + f = 0 + for train_ix, test_ix in kfold.split(X, y): + # select rows + train_X, test_X = X[train_ix], X[test_ix] + train_y, test_y = y[train_ix], y[test_ix] + # summarize train and test composition + train_0, train_1 = len(train_y[train_y == 0]), len(train_y[train_y == 1]) + test_0, test_1 = len(test_y[test_y == 0]), len(test_y[test_y == 1]) + print('>Train: 0=%d, 1=%d, Test: 0=%d, 1=%d' % (train_0, train_1, test_0, test_1)) + print(test_X) + tt = dataset[dataset['sample_desc'].isin(test_X)] + db['data_fold' + str(f)] = tt + f += 1 + + temp = db['data_fold0'] + for i in range(1,4): + temp = temp.append(db['data_fold'+str(i)], ignore_index=True) + train = temp + validation = db['data_fold4'] + heldout = db['data_fold5'] + + features_to_normalize = ['findings_size', 'findings_x_max', 'findings_y_max', 'DistanceSourceToPatient', + 'DistanceSourceToDetector', 'x_pixel_spacing', 'XRayTubeCurrent', 'CompressionForce', + 'exposure_time', 'KVP', 'body_part_thickness', 'RelativeXRayExposure', 'exposure_in_mas', + 'age'] + + for feature in features_to_normalize: + train[feature] = (train[feature] - train[feature].mean()) / train[feature].std() + validation[feature] = (validation[feature] - validation[feature].mean()) / validation[feature].std() + heldout[feature] = (heldout[feature] - heldout[feature].mean()) /heldout[feature].std() + + with open(path + 'dataset_MG_clinical_train' + '.pickle','wb') as handle: + pickle.dump(train, handle, protocol=pickle.HIGHEST_PROTOCOL) + with open(path + 'dataset_MG_clinical_train_xml.txt', 'w') as f: + f.write(train['xml_url'].str.cat(sep='\n')) + train.to_csv(path + 'dataset_MG_clinical_train.csv') + with open(path + 'dataset_MG_clinical_validation' + '.pickle', 'wb') as handle: + pickle.dump(validation, handle, protocol=pickle.HIGHEST_PROTOCOL) + with open(path + 'dataset_MG_clinical_validation_xml.txt', 'w') as f: + f.write(validation['xml_url'].str.cat(sep='\n')) + validation.to_csv(path + 'dataset_MG_clinical_validation.csv') + with open(path + 'dataset_MG_clinical_heldout' + '.pickle', 'wb') as handle: + pickle.dump(heldout, handle, protocol=pickle.HIGHEST_PROTOCOL) 
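As a quick sanity check, a minimal sketch of reading the pickled splits back the way a downstream runner would (same 'path' prefix as in this script; the schema assertions are illustrative assumptions):

    import pickle
    import pandas as pd

    with open(path + 'dataset_MG_clinical_train.pickle', 'rb') as handle:
        train_df = pickle.load(handle)
    with open(path + 'dataset_MG_clinical_validation.pickle', 'rb') as handle:
        validation_df = pickle.load(handle)
    # Each pickle holds a pandas DataFrame keyed by 'sample_desc', carrying the
    # one-hot clinical features and the 0/1 'biopsy' label prepared above.
    assert isinstance(train_df, pd.DataFrame)
    assert {'sample_desc', 'biopsy'}.issubset(train_df.columns)
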
+    with open(path + 'dataset_MG_clinical_heldout_xml.txt', 'w') as f:
+        f.write(heldout['xml_url'].str.cat(sep='\n'))
+    heldout.to_csv(path + 'dataset_MG_clinical_heldout.csv')
 
-    predictor = TabularPredictor(label=label, path=save_path, eval_metric='roc_auc').fit(train_set)
-    results = predictor.fit_summary(show_plot=True)
-
-    # Inference time:
-    y_test = test_set[label]
-    test_data = test_set.drop(labels=[label],
-                              axis=1)  # delete labels from test data since we wouldn't have them in practice
-    print(test_data.head())
-
-    predictor = TabularPredictor.load(
-        save_path)
-    y_pred = predictor.predict_proba(test_data)
-    perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True)
\ No newline at end of file
diff --git a/fuse_examples/classification/multimodality/mg_dataset.py b/fuse_examples/classification/multimodality/mg_dataset.py
index 7bce70174..1a4152203 100644
--- a/fuse_examples/classification/multimodality/mg_dataset.py
+++ b/fuse_examples/classification/multimodality/mg_dataset.py
@@ -14,7 +14,7 @@
 from fuse_examples.classification.multimodality.dataset import imaging_tabular_dataset
 from fuse.data.dataset.dataset_default import FuseDatasetDefault
 
-from fuse_examples.classification.MG_CMMD.input_processor import FuseMGInputProcessor
+from fuse_examples.classification.cmmd.input_processor import FuseMGInputProcessor
 from fuse.data.processor.processor_dataframe import FuseProcessorDataFrame
 
 
diff --git a/fuse_examples/classification/multimodality/model_tabular_imaging.py b/fuse_examples/classification/multimodality/model_tabular_imaging.py
index dbd361575..cfa0774e8 100644
--- a/fuse_examples/classification/multimodality/model_tabular_imaging.py
+++ b/fuse_examples/classification/multimodality/model_tabular_imaging.py
@@ -1,10 +1,17 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fuse.models.backbones.backbone_mlp import FuseMultilayerPerceptronBackbone
+from torch import Tensor
+from torch.hub import load_state_dict_from_url
+from torchvision.models.video.resnet import VideoResNet, BasicBlock, Conv3DSimple, BasicStem, model_urls
+from typing import Dict, Tuple, Any, List, Sequence, Callable
 from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
-from typing import Dict, Tuple, Sequence
-from fuse.models.backbones.backbone_inception_resnet_v2 import FuseBackboneInceptionResnetV2
+from fuse.models.model_default import FuseModelDefault
+from fuse.models.heads.head_3D_classifier import FuseHead3dClassifier
+
+#-----------------------------------------------------------------
+#Encoder fusion models
+#-----------------------------------------------------------------
 
 class project_imaging(nn.Module):
 
@@ -22,12 +29,14 @@ def forward(self, imaging_features):
                 imaging_features = F.max_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
             else:
                 imaging_features = F.max_pool3d(imaging_features, kernel_size=imaging_features.shape[2:])
+                imaging_features = torch.squeeze(imaging_features, len(imaging_features.shape)-1)
         elif self.pooling == 'avg':
             if self.dim == '2d':
                 imaging_features = F.avg_pool2d(imaging_features, kernel_size=imaging_features.shape[2:])
             else:
                 imaging_features = F.avg_pool3d(imaging_features, kernel_size=imaging_features.shape[2:])
+                imaging_features = torch.squeeze(imaging_features, len(imaging_features.shape)-1)
 
         if self.projection_imaging is not None:
             imaging_features = self.projection_imaging.forward(imaging_features)
@@ -35,7 +44,6 @@ def forward(self, imaging_features):
 
         return imaging_features
 
-
class project_tabular(nn.Module):
 
     def __init__(self, projection_tabular: nn.Module = None):
@@ -172,3 +180,193 @@ def forward(self, batch_dict: Dict) -> Dict:
             batch_dict = head.forward(batch_dict)
 
         return batch_dict['model']
+
+
+#-----------------------------------------------------------------
+#Interactive models
+#-----------------------------------------------------------------
+
+def channel_multiplication(vector, matrix):
+    return matrix * vector.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+
+class FuseBackboneResnet3DInteractive(VideoResNet):
+    """
+    3D model classifier (ResNet architecture)
+    """
+
+    def __init__(self,
+                 conv_inputs: Tuple[Tuple[str, int], ...] = (('data.image', 1),),
+                 fcn_inputs: Tuple[Tuple[str, int], ...] = (('data.input.clinical.all', 1),),
+                 fcn_layers: List[int] = [64, 64, 128, 256, 512],  #VideoResNet layers
+                 fcn_input_size: int = 11,
+                 interact_function: Callable = channel_multiplication,
+                 cnn_interact_function: Callable = None,
+                 fcn_cnn_layers_interactions: List[int] = None,
+                 fcn_cnn_layers_parallels: List[int] = None,
+                 use_relu_in_fcn: bool = True,
+                 use_batcn_norm_in_fcn: bool = False,
+                 pretrained: bool = False, in_channels: int = 1,
+                 name: str = "r3d_18") -> None:
+        """
+        Create 3D ResNet model
+        :param pretrained: Use pretrained weights
+        :param in_channels: number of input channels
+        :param name: model name. currently only 'r3d_18' is supported
+        """
+        # init parameters per required backbone
+        init_parameters = {
+            'r3d_18': {'block': BasicBlock,
+                       'conv_makers': [Conv3DSimple] * 4,
+                       'layers': [2, 2, 2, 2],
+                       'stem': BasicStem},
+        }[name]
+
+        # init original model
+        super().__init__(**init_parameters)
+
+        # load pretrained parameters if required
+        if pretrained:
+            state_dict = load_state_dict_from_url(model_urls[name])
+            self.load_state_dict(state_dict)
+
+        # =================================
+        self.use_relu_in_fcn = use_relu_in_fcn
+        self.use_batcn_norm_in_fcn = use_batcn_norm_in_fcn
+        self.cnn_interact_function = cnn_interact_function
+
+        fcn_data = [nn.Linear(fcn_input_size, fcn_layers[0])]
+        if self.use_relu_in_fcn:
+            fcn_data.append(nn.ReLU(inplace=False))
+        if self.use_batcn_norm_in_fcn:
+            fcn_data.append(nn.BatchNorm1d(fcn_layers[0], eps=0.001, momentum=0.01, affine=True))
+
+        for layer_idx in range(len(fcn_layers) - 1):
+            fcn_data.append(nn.Linear(fcn_layers[layer_idx], fcn_layers[layer_idx + 1]))
+            fcn_data.append(nn.ReLU(inplace=False)) if self.use_relu_in_fcn else None
+            fcn_data.append(nn.BatchNorm1d(fcn_layers[layer_idx + 1], eps=0.001, momentum=0.01,
+                                           affine=True)) if self.use_batcn_norm_in_fcn else None
+
+        self.interactive_fcn = nn.ModuleList(fcn_data)
+        if interact_function is None and fcn_cnn_layers_interactions is not None:
+            raise AssertionError("fcn_cnn_layers_interactions are defined but no interactive function is provided")
+        self.fcn_interact_function = interact_function
+        self.fcn_cnn_layers_parallels = fcn_cnn_layers_parallels or range(
+            len(self.interactive_fcn))  # either specified or all cnn layers
+        self.fcn_cnn_layers_interactions = fcn_cnn_layers_interactions or self.fcn_cnn_layers_parallels  # either specified or all parallel layers
+
+        #=================================
+        # save input parameters
+        self.pretrained = pretrained
+        self.in_channels = in_channels
+        # override the first convolution layer to support any number of input channels
+        self.stem = nn.Sequential(
+            nn.Conv3d(self.in_channels, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
+                      padding=(1, 3, 3), bias=False),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True)
+        )
+        
self.conv_inputs = conv_inputs + self.fcn_inputs = fcn_inputs + + def features(self, batch_dict: Dict) -> Any: + """ + Extract spatial features - given a 3D tensor + :param x: Input tensor - shape: [batch_size, channels, z, y, x] + :return: spatial features - shape [batch_size, n_features, z', y', x'] + """ + + conv_input = torch.cat([FuseUtilsHierarchicalDict.get(batch_dict, conv_input[0]) for conv_input in self.conv_inputs], 1) + fcn_input = torch.cat([FuseUtilsHierarchicalDict.get(batch_dict, fcn_input[0]) for fcn_input in self.fcn_inputs], 1) + + + interactive_idx=0 + x = self.stem(conv_input) + out, y, interactive_idx = self.apply_interactive_fcn(interactive_idx, x, fcn_input, 0) + out = self.layer1(out) + out, y, interactive_idx = self.apply_interactive_fcn(interactive_idx, out, y, 1) + out = self.layer2(out) + out, y, interactive_idx = self.apply_interactive_fcn(interactive_idx, out, y, 2) + out = self.layer3(out) + out, y, interactive_idx = self.apply_interactive_fcn(interactive_idx, out, y, 3) + out = self.layer4(out) + out, y, interactive_idx = self.apply_interactive_fcn(interactive_idx, out, y, 4) + + return out + + def apply_interactive_fcn(self, interactive_idx, x, y, cnn_layer_idx): + if cnn_layer_idx in self.fcn_cnn_layers_parallels: # only if this resnet layer should be interacted with fcn + if interactive_idx < len(self.interactive_fcn): + y = self.interactive_fcn[interactive_idx](y) # fully connected layer + interactive_idx += 1 + if self.use_relu_in_fcn: + y = self.interactive_fcn[interactive_idx](y) # ReLU + interactive_idx += 1 + if self.use_batcn_norm_in_fcn: + y = self.interactive_fcn[interactive_idx](y) # BatchNorm + interactive_idx += 1 + if self.fcn_interact_function is not None: + # only apply interact function if the layer is in the resnet layer we want to interact with + if cnn_layer_idx in self.fcn_cnn_layers_interactions: + x = self.fcn_interact_function(y, x) + if self.cnn_interact_function is not None: + # only apply interact function if the layer is in the resnet layer we want to interact with + if cnn_layer_idx in self.fcn_cnn_layers_interactions: + y = self.cnn_interact_function(y, x) + + return x, y, interactive_idx + + def forward(self, x: Tensor) -> Tuple[Tensor, None, None, None]: # type: ignore + """ + Forward pass. 3D global classification given a volume + :param x: Input volume. shape: [batch_size, channels, z, y, x] + :return: logits for global classification. shape: [batch_size, n_classes]. + """ + x = self.features(x) + return x + +class FuseModelDefaultInteractive(FuseModelDefault): + def __init__(self, + conv_inputs: Tuple[Tuple[str, int], ...] = (('data.input', 1),), + cnn_inputs: Tuple[Tuple[str, int], ...] 
+
+class FuseModelDefaultInteractive(FuseModelDefault):
+    def __init__(self,
+                 conv_inputs: Tuple[Tuple[str, int], ...] = (('data.input', 1),),
+                 cnn_inputs: Tuple[Tuple[str, int], ...] = (('data.clinical', 1),),
+                 backbone: torch.nn.Module = None,
+                 heads: Sequence[torch.nn.Module] = None,
+                 freeze_backbone=False,
+                 ) -> None:
+        """
+        Default Fuse model - convolutional neural network with multiple heads
+        :param conv_inputs: batch_dict name for model input and its number of input channels (imaging)
+        :param cnn_inputs: batch_dict name for model input and its number of input channels (tabular features)
+        :param backbone: PyTorch backbone module - a convolutional neural network
+        :param heads: Sequence of head modules
+        :param freeze_backbone: if True, train only the interactive fcn branch and the heads
+        """
+        FuseModelDefault.__init__(self, conv_inputs=conv_inputs, heads=heads, backbone=backbone)
+
+        self.cnn_inputs = cnn_inputs
+        self.freeze_backbone = freeze_backbone
+
+    def train(self, mode: bool = True):
+        if self.freeze_backbone:
+            self.backbone.eval()
+            self.backbone.interactive_fcn.train(mode)
+        else:
+            self.backbone.train(mode)
+        self.heads.train(mode)
+        return self
+
+    def forward(self,
+                batch_dict: Dict) -> Dict:
+        """
+        Forward function of the model
+        :param batch_dict: batch dict holding the imaging and tabular inputs
+        :return: classification scores - [BATCH_SIZE, num_classes]
+        """
+
+        features = self.backbone(batch_dict)
+        FuseUtilsHierarchicalDict.set(batch_dict, 'model.backbone_features', features)
+
+        for head in self.heads:
+            batch_dict = head.forward(batch_dict)
+
+        return batch_dict['model']
+
diff --git a/fuse_examples/classification/multimodality/multimodal_paths.py b/fuse_examples/classification/multimodality/multimodal_paths.py
index 4112c8f12..fc3aa2cd5 100644
--- a/fuse_examples/classification/multimodality/multimodal_paths.py
+++ b/fuse_examples/classification/multimodality/multimodal_paths.py
@@ -1,4 +1,7 @@
 import os
+import pathlib
+import pandas as pd
+from fuse.utils.rand.seed import Seed
 
 def multimodal_paths(dataset_name,root_data,root, experiment,cache_path):
     if dataset_name=='mg_clinical':
@@ -43,6 +46,28 @@ def multimodal_paths(dataset_name,root_data,root, experiment,cache_path):
              # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
              'cache_dir': os.path.join(cache_path),
              'inference_dir': os.path.join(root_data,'model_mg_radiologist_usa/'+experiment)}
+    if dataset_name == 'knight':
+        # read the train/val splits file. for convenience, we use the one
+        # auto-generated by the nnU-Net framework for the KiTS21 data
+        dir_path = '/projects/msieve_dev2/usr/Tal/git_repos_multimodality/fuse-med-ml/fuse_examples/classification/knight/baseline'
+        splits = pd.read_pickle(os.path.join(dir_path, 'splits_final.pkl'))
+        # For this example, we use split 0 out of the 5 available cross-validation splits
+        split = splits[0]
+        # set a constant seed for reproducibility (deterministic mode is left off for speed)
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"  # required for pytorch deterministic mode
+        rand_gen = Seed.set_seed(1234, deterministic_mode=False)
+
+        paths = {
+            # paths
+            'data_dir': '/projects/msieve/MedicalSieve/PatientData/KNIGHT/',
+            'split': split,
+            'seed': rand_gen,
+            'force_reset_model_dir': False,
+            # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
+ 'cache_dir': os.path.join(cache_path), + 'rand_gen':rand_gen, + 'model_dir': os.path.join(root_data, 'knight/' + experiment), + 'inference_dir': os.path.join(root_data, 'knight/' + experiment)} return paths \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/multimodel_parameters.py b/fuse_examples/classification/multimodality/multimodel_parameters.py index 2a9d73c27..26c066183 100644 --- a/fuse_examples/classification/multimodality/multimodel_parameters.py +++ b/fuse_examples/classification/multimodality/multimodel_parameters.py @@ -7,10 +7,12 @@ from fuse.models.heads.head_global_pooling_classifier import FuseHeadGlobalPoolingClassifier from fuse.models.heads.head_1d_classifier import FuseHead1dClassifier from fuse.models.model_ensemble import FuseModelEnsemble +from fuse.models.heads.head_3D_classifier import FuseHead3dClassifier def multimodal_parameters(train_common_params: dict,infer_common_params: dict,analyze_common_params: dict): - + num_classes = train_common_params['num_classes'] + target_metric = train_common_params['target_metric'].replace('metrics.','') ################################################ # backbone_models model_tabular = FuseModelTabularContinuousCategorical( @@ -27,6 +29,12 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an model_projection_imaging = train_common_params['imaging_projector'] model_projection_tabular = train_common_params['tabular_projector'] + model_interactive_3d = FuseBackboneResnet3DInteractive( + conv_inputs=(('data.image', 1),), + fcn_inputs=(('data.input.clinical.all', 1),), + ) + + heads_for_multimodal = { 'multimodal_head': [ @@ -34,7 +42,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an head_name='multimodal', conv_inputs=(('model.multimodal_features', train_common_params['tabular_feature_size'] * 2),), - num_classes=2, + num_classes=num_classes, ) ], 'tabular_head': @@ -43,7 +51,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an head_name='tabular', conv_inputs=(('model.tabular_features', train_common_params['tabular_feature_size']),), - num_classes=2, + num_classes=num_classes, ) ], 'imaging_head': @@ -54,11 +62,35 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an layers_description=(256,), conv_inputs=(('model.imaging_features', train_common_params['imaging_feature_size']),), - num_classes=2, + num_classes=num_classes, pooling="avg", ) ], + 'imaging_head_3d': + [ + FuseHead3dClassifier( + head_name='imaging', + dropout_rate=0.5, + layers_description=(256,), + conv_inputs=(('model.imaging_features', + train_common_params['imaging_feature_size']),), + num_classes=num_classes, + ) + ], + + 'interactive_head_3d': + [ + FuseHead3dClassifier( + head_name='interactive', + dropout_rate=0.5, + layers_description=(256,), + conv_inputs=(('model.backbone_features', + train_common_params['imaging_feature_size']),), + num_classes=num_classes, + ) + ], + } loss_for_multimodal = { @@ -68,6 +100,8 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an callable=F.cross_entropy, weight=1.0), 'imaging_loss':FuseLossDefault(pred_name='model.logits.imaging', target_name='data.gt', callable=F.cross_entropy, weight=1.0), + 'interactive_loss': FuseLossDefault(pred_name='model.logits.interactive', target_name='data.gt', + callable=F.cross_entropy, weight=1.0), 'ensemble_loss':FuseLossDefault(pred_name='model.output.tabular_ensemble_average', 
target_name='data.gt', callable=F.nll_loss, weight=1.0,reduction='sum'), @@ -76,6 +110,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'multimodal_auc': MetricAUCROC(pred='model.output.multimodal', target='data.gt'), 'tabular_auc': MetricAUCROC(pred='model.output.tabular', target='data.gt'), 'imaging_auc': MetricAUCROC(pred='model.output.imaging', target='data.gt'), + 'interactive_auc': MetricAUCROC(pred='model.output.interactive', target='data.gt'), 'ensemble_auc':MetricAUCROC(pred='model.output.tabular_ensemble_average', target='data.gt'), } ################################################ @@ -92,7 +127,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'cls_loss': loss_for_multimodal['tabular_loss'], } train_common_params['metrics'] = { - 'auc': metric_for_multimodal['tabular_auc'], + target_metric: metric_for_multimodal['tabular_auc'], } train_common_params['manager.learning_rate'] = 1e-4 @@ -106,19 +141,21 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an weight_decay=train_common_params['manager.weight_decay']) train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], gamma=train_common_params['manager.gamma']) + analyze_common_params['metrics'] = train_common_params['metrics'] + infer_common_params['output_keys'] = ['model.output.tabular', 'data.gt'] if train_common_params['fusion_type'] == 'mono_imaging': train_common_params['model'] = FuseMultiModalityModel( imaging_inputs=(('data.image', 1),), imaging_backbone=model_imaging, - heads=heads_for_multimodal['imaging_head'], + heads=heads_for_multimodal['imaging_head_3d'], ) train_common_params['loss'] = { 'cls_loss': loss_for_multimodal['imaging_loss'], } train_common_params['metrics'] = { - 'auc': metric_for_multimodal['imaging_auc'], + target_metric: metric_for_multimodal['imaging_auc'], } train_common_params['manager.learning_rate'] = 1e-5 train_common_params['manager.weight_decay'] = 0.001 @@ -126,6 +163,8 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an train_common_params['optimizer'] = optim.Adam(train_common_params['model'].parameters(), lr=train_common_params['manager.learning_rate'], weight_decay=train_common_params['manager.weight_decay']) train_common_params['scheduler'] = optim.lr_scheduler.ReduceLROnPlateau(train_common_params['optimizer']) + analyze_common_params['metrics'] = train_common_params['metrics'] + infer_common_params['output_keys'] = ['model.output.imaging', 'data.gt'] if train_common_params['fusion_type'] == 'late_fusion': @@ -144,7 +183,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'cls_loss': loss_for_multimodal['multimodal_loss'], } train_common_params['metrics'] = { - 'auc': metric_for_multimodal['multimodal_auc'], + target_metric: metric_for_multimodal['multimodal_auc'], } train_common_params['manager.learning_rate'] = 1e-4 train_common_params['manager.weight_decay'] = 1e-4 @@ -158,10 +197,15 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], gamma=train_common_params['manager.gamma']) + analyze_common_params['metrics'] = train_common_params['metrics'] = { + target_metric: metric_for_multimodal['multimodal_auc'], + } + infer_common_params['output_keys'] = 
['model.output.multimodal', 'data.gt'] + if train_common_params['fusion_type'] == 'ensemble': - train_common_params['tabular_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/model_mg_radiologist_usa/mono_tabular/' - train_common_params['imaging_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/model_mg_radiologist_usa/mono_imaging_no_aug/' + train_common_params['tabular_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/knight/mono_tabular/' + train_common_params['imaging_dir'] = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/knight/mono_imaging/' train_common_params['model'] = FuseModelEnsemble(input_model_dirs=[train_common_params['tabular_dir'], train_common_params['imaging_dir']]) infer_common_params['model_dir'] = [train_common_params['tabular_dir'], @@ -175,10 +219,9 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an 'model.output.tabular_ensemble_majority_vote'] analyze_common_params['metrics'] = train_common_params['metrics'] = { - 'auc': metric_for_multimodal['ensemble_auc'], + target_metric: metric_for_multimodal['ensemble_auc'], } - if train_common_params['fusion_type'] == 'contrastive': train_common_params['model'] = FuseMultiModalityModel( tabular_inputs=(('data.continuous', 1), ('data.categorical', 1),), @@ -210,4 +253,33 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an train_common_params['scheduler'] = optim.lr_scheduler.StepLR(train_common_params['optimizer'], step_size=train_common_params['manager.step_size'], gamma=train_common_params['manager.gamma']) + + + if train_common_params['fusion_type'] == 'interactive': + + train_common_params['model'] = FuseModelDefaultInteractive(backbone=FuseBackboneResnet3DInteractive( + conv_inputs=(('data.image', 1),), + fcn_inputs=(('data.input.clinical.all', 1),), + ), + heads=heads_for_multimodal['interactive_head_3d'], + ) + + train_common_params['loss'] = { + 'cls_loss': loss_for_multimodal['interactive_loss'], + } + train_common_params['metrics'] = { + target_metric: metric_for_multimodal['interactive_auc'], + } + train_common_params['manager.learning_rate'] = 1e-5 + train_common_params['manager.weight_decay'] = 0.001 + + train_common_params['optimizer'] = optim.Adam(train_common_params['model'].parameters(), lr=train_common_params['manager.learning_rate'], + weight_decay=train_common_params['manager.weight_decay']) + train_common_params['scheduler'] = optim.lr_scheduler.ReduceLROnPlateau(train_common_params['optimizer']) + analyze_common_params['metrics'] = train_common_params['metrics'] + infer_common_params['output_keys'] = ['model.output.interactive', 'data.gt'] + + + + return train_common_params,infer_common_params,analyze_common_params \ No newline at end of file diff --git a/fuse_examples/classification/multimodality/runner.py b/fuse_examples/classification/multimodality/runner.py index 7bbcf6a07..8835e96b7 100644 --- a/fuse_examples/classification/multimodality/runner.py +++ b/fuse_examples/classification/multimodality/runner.py @@ -64,16 +64,16 @@ ########################################## dataset_name = 'mg_radiologic' root = '' -root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/' # TODO: add path to the data folder +root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/mg/' # TODO: add path to the data folder assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location" # Name of the experiment -experiment = 
'contrastive_non_annotated' +experiment = 'late_fusion2' # Path to cache data cache_path = root_data+'/mg_radiologic/' paths = multimodal_paths(dataset_name, root_data, root, experiment, cache_path) TRAIN_COMMON_PARAMS['paths'] = paths -TRAIN_COMMON_PARAMS['fusion_type'] = 'contrastive' +TRAIN_COMMON_PARAMS['fusion_type'] = 'late_fusion' ###################################### # Inference Common Params ###################################### @@ -81,7 +81,8 @@ INFER_COMMON_PARAMS['infer_filename'] = os.path.join(TRAIN_COMMON_PARAMS['paths']['inference_dir'],'validation_set_infer.gz') INFER_COMMON_PARAMS['checkpoint'] = 'best' # Fuse TIP: possible values are 'best', 'last' or epoch_index. INFER_COMMON_PARAMS['data.train_num_workers'] = TRAIN_COMMON_PARAMS['data.train_num_workers'] - +INFER_COMMON_PARAMS['output_keys'] = ['model.output.multimodal','data.gt'] +INFER_COMMON_PARAMS['model_dir'] = TRAIN_COMMON_PARAMS['paths']['model_dir'] # Analyze Common Params ###################################### ANALYZE_COMMON_PARAMS = {} @@ -106,8 +107,8 @@ # best_epoch_source # if an epoch values are the best so far, the epoch is saved as a checkpoint. TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = { - 'source':'losses.cls_loss',#'metrics.auc',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries - 'optimization': 'min', # can be either min/max + 'source':'metrics.auc',#'losses.cls_loss',# 'metrics.auc.macro_avg', # can be any key from losses or metrics dictionaries + 'optimization': 'max', # can be either min/max 'on_equal_values': 'better', # can be either better/worse - whether to consider best epoch when values are equal } @@ -123,7 +124,7 @@ ['gt'], features_dic['annotated_feat'], features_dic['non_annotated_feat'], - use_annotated=False, use_non_annotated=True) + use_annotated=True, use_non_annotated=True) #define processors TRAIN_COMMON_PARAMS['imaging_processor'] = FuseMGInputProcessor @@ -340,11 +341,11 @@ def run_analyze(paths: dict, analyze_common_params: dict): # run # FIXME: simplify analyze interface for this case - results = analyzer.eval(gt_processors={}, - data_pickle_filename=os.path.join(paths["inference_dir"], + results = analyzer.eval(ids=None, + data=os.path.join(paths["inference_dir"], analyze_common_params["infer_filename"]), metrics=metrics, - output_filename=analyze_common_params['output_filename']) + output_dir=analyze_common_params['output_filename']) return results @@ -362,7 +363,7 @@ def run_analyze(paths: dict, analyze_common_params: dict): force_gpus = None # [0] FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus) - RUNNING_MODES = ['train']#['train', 'infer', 'analyze'] # Options: 'train', 'infer', 'analyze' + RUNNING_MODES = ['train','infer','analyze']#['train', 'infer', 'analyze'] # Options: 'train', 'infer', 'analyze' paths = TRAIN_COMMON_PARAMS['paths'] diff --git a/fuse_examples/classification/multimodality/runner_knight.py b/fuse_examples/classification/multimodality/runner_knight.py new file mode 100644 index 000000000..8c652d89a --- /dev/null +++ b/fuse_examples/classification/multimodality/runner_knight.py @@ -0,0 +1,389 @@ +""" + +(C) Copyright 2021 IBM Corp. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+Created on June 30, 2021
+
+"""
+
+import os
+import logging
+from torch.utils.data.dataloader import DataLoader
+import torch.nn as nn
+
+import fuse.utils.gpu as FuseUtilsGPU
+from fuse.utils.utils_debug import FuseUtilsDebug
+from fuse.utils.utils_logger import fuse_logger_start
+from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
+from fuse.managers.callbacks.callback_metric_statistics import FuseMetricStatisticsCallback
+from fuse.managers.callbacks.callback_time_statistics import FuseTimeStatisticsCallback
+from fuse.managers.manager_default import FuseManagerDefault
+from fuse.models.backbones.backbone_resnet_3d import FuseBackboneResnet3D
+
+from fuse.models.backbones.backbone_mlp import FuseMultilayerPerceptronBackbone
+from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
+from fuse.eval.evaluator import EvaluatorDefault
+
+from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters
+from fuse_examples.classification.knight.baseline.dataset import knight_dataset
+
+from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths
+from fuse_examples.classification.multimodality.model_tabular_imaging import project_imaging, project_tabular
+
+def tabular_feature_knight():
+    features_dict = {}
+
+    # gender, comorbidities and smoking history are the categorical features (one-hot
+    # encoded by the dataset); eGFR, radiographic size, BMI and age are continuous
+    features_dict['categorical_clinical_feat'] = ['smoking_history', 'comorbidities', 'gender']  # 3 categorical clinical features
+    features_dict['continuous_clinical_feat'] = ['last_preop_egfr', 'radiographic_size', 'body_mass_index', 'age_at_nephrectomy']  # 4 continuous clinical features
+
+    return features_dict
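+
+# With the lists above, the clinical encoding built by dataset_knight.prepare_clinical is an
+# 11-dim vector: 4 continuous values plus a 7-dim one-hot block
+# (gender: 2, comorbidities: 2, smoking_history: 3); the encoder sizes below follow these dimensions.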
+
+
+##########################################
+# Debug modes
+##########################################
+mode = 'default'  # Options: 'default', 'fast', 'debug', 'verbose', 'user'. See details in FuseUtilsDebug
+debug = FuseUtilsDebug(mode)
+
+##########################################
+# Train Common Params
+##########################################
+# ============
+# Data
+# ============
+
+TRAIN_COMMON_PARAMS = {}
+
+TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8
+TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8
+
+##########################################
+# Dataset
+##########################################
+dataset_name = 'knight'
+TRAIN_COMMON_PARAMS['task_num'] = 1
+root = ''
+root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/'  # TODO: add path to the data folder
+assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location"
+# Name of the experiment
+experiment = 'interactive'
+# Path to cache data
+cache_path = root_data+'/knight/cache_knight_256_256_64/'
+
+paths = multimodal_paths(dataset_name, root_data, root, experiment, cache_path)
+TRAIN_COMMON_PARAMS['paths'] = paths
+TRAIN_COMMON_PARAMS['fusion_type'] = 'interactive'
+######################################
+# Inference Common Params
+######################################
+INFER_COMMON_PARAMS = {}
+INFER_COMMON_PARAMS['infer_filename'] = os.path.join(TRAIN_COMMON_PARAMS['paths']['inference_dir'], 'validation_set_infer.gz')
+INFER_COMMON_PARAMS['checkpoint'] = 'best'  # Fuse TIP: possible values are 'best', 'last' or epoch_index.
+INFER_COMMON_PARAMS['data.train_num_workers'] = TRAIN_COMMON_PARAMS['data.train_num_workers']
+INFER_COMMON_PARAMS['model_dir'] = TRAIN_COMMON_PARAMS['paths']['model_dir']
+######################################
+# Analyze Common Params
+######################################
+ANALYZE_COMMON_PARAMS = {}
+ANALYZE_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
+ANALYZE_COMMON_PARAMS['output_filename'] = os.path.join(TRAIN_COMMON_PARAMS['paths']['inference_dir'], 'all_metrics.txt')
+ANALYZE_COMMON_PARAMS['num_workers'] = 4
+ANALYZE_COMMON_PARAMS['batch_size'] = 2
+
+#----------------------
+# task related parameters
+
+if TRAIN_COMMON_PARAMS['task_num'] == 1:
+    TRAIN_COMMON_PARAMS['num_classes'] = 2
+    TRAIN_COMMON_PARAMS['target_name'] = 'data.gt.gt_global.task_1_label'
+    TRAIN_COMMON_PARAMS['target_metric'] = 'metrics.auc'
+
+elif TRAIN_COMMON_PARAMS['task_num'] == 2:
+    TRAIN_COMMON_PARAMS['num_classes'] = 5
+    TRAIN_COMMON_PARAMS['target_name'] = 'data.gt.gt_global.task_2_label'
+    TRAIN_COMMON_PARAMS['target_metric'] = 'metrics.auc.macro_avg'
+
+# ===============
+# Manager - Train
+# ===============
+NUM_GPUS = 2
+TRAIN_COMMON_PARAMS['data.batch_size'] = 2 * NUM_GPUS
+TRAIN_COMMON_PARAMS['manager.train_params'] = {
+    'num_gpus': NUM_GPUS,
+    'num_epochs': 300,
+    'virtual_batch_size': 1,  # number of batches in one virtual batch
+    'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
+    'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoints
+}
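+
+# With the settings above, periodic checkpoints are saved from epoch 10 on, every 5 epochs,
+# in addition to the running 'best' checkpoint selected by the criterion below.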
+
+# best_epoch_source
+# if an epoch's values are the best so far, the epoch is saved as a checkpoint.
+TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
+    'source': TRAIN_COMMON_PARAMS['target_metric'],  # can be any key from the losses or metrics dictionaries (e.g. 'losses.cls_loss', 'metrics.auc.macro_avg')
+    'optimization': 'max',  # can be either min/max
+    'on_equal_values': 'better',
+    # can be either better/worse - whether to consider best epoch when values are equal
+}
+
+TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None
+
+
+# define the post-processing function
+features_dic = tabular_feature_knight()
+TRAIN_COMMON_PARAMS['post_processing'] = None
+
+# define the encoders
+TRAIN_COMMON_PARAMS['imaging_feature_size'] = 512
+TRAIN_COMMON_PARAMS['tabular_feature_size'] = 256
+# one-hot encoding size of the categorical features: gender (2) + comorbidities (2) + smoking_history (3)
+CATEGORICAL_ONE_HOT_SIZE = 7
+TRAIN_COMMON_PARAMS['tabular_encoder_categorical'] = FuseMultilayerPerceptronBackbone(
+    layers=[128, CATEGORICAL_ONE_HOT_SIZE],
+    mlp_input_size=CATEGORICAL_ONE_HOT_SIZE)
+TRAIN_COMMON_PARAMS['tabular_encoder_continuous'] = None
+TRAIN_COMMON_PARAMS['tabular_encoder_cat'] = FuseMultilayerPerceptronBackbone(
+    layers=[TRAIN_COMMON_PARAMS['tabular_feature_size']],
+    mlp_input_size=CATEGORICAL_ONE_HOT_SIZE + len(features_dic['continuous_clinical_feat']))
+
+
+TRAIN_COMMON_PARAMS['imaging_encoder'] = FuseBackboneResnet3D(in_channels=1)
+TRAIN_COMMON_PARAMS['imaging_projector'] = project_imaging(
+    projection_imaging=nn.Conv2d(TRAIN_COMMON_PARAMS['imaging_feature_size'],
+                                 TRAIN_COMMON_PARAMS['tabular_feature_size'],
+                                 kernel_size=1, stride=1),
+    dim='3d')
+TRAIN_COMMON_PARAMS['tabular_projector'] = None
+
+TRAIN_COMMON_PARAMS['dataset_func'] = knight_dataset(
+    data_dir=TRAIN_COMMON_PARAMS['paths']['data_dir'],
+    cache_dir=TRAIN_COMMON_PARAMS['paths']['cache_dir'],
+    split=TRAIN_COMMON_PARAMS['paths']['split'],
+    reset_cache=TRAIN_COMMON_PARAMS['paths']['force_reset_model_dir'],
+    rand_gen=TRAIN_COMMON_PARAMS['paths']['rand_gen'],
+    batch_size=TRAIN_COMMON_PARAMS['data.batch_size'],
+    resize_to=(256, 256, 64),
+    task_num=TRAIN_COMMON_PARAMS['task_num'],
+    target_name=TRAIN_COMMON_PARAMS['target_name'],
+    num_classes=TRAIN_COMMON_PARAMS['num_classes'])
+
+
+TRAIN_COMMON_PARAMS, INFER_COMMON_PARAMS, ANALYZE_COMMON_PARAMS = multimodal_parameters(TRAIN_COMMON_PARAMS, INFER_COMMON_PARAMS, ANALYZE_COMMON_PARAMS)
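+
+# multimodal_parameters wires the model, loss and metrics according to
+# TRAIN_COMMON_PARAMS['fusion_type']; for 'interactive' it builds FuseModelDefaultInteractive
+# around FuseBackboneResnet3DInteractive and reports AUC on 'model.output.interactive'.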
+
+#################################
+# Train Template
+#################################
+def run_train(paths: dict, train_common_params: dict, reset_cache: bool):
+    # ==============================================================================
+    # Logger
+    # ==============================================================================
+    fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO)
+    lgr = logging.getLogger('Fuse')
+
+    # Download data
+    # TBD
+
+    lgr.info('\nFuse Train', {'attrs': ['bold', 'underline']})
+
+    lgr.info(f'model_dir={paths["model_dir"]}', {'color': 'magenta'})
+    lgr.info(f'cache_dir={paths["cache_dir"]}', {'color': 'magenta'})
+
+    #### Train Data
+    lgr.info('Train Data:', {'attrs': 'bold'})
+
+    train_dataloader, validation_dataloader, test_dataloader, train_dataset, validation_dataset, test_dataset = train_common_params['dataset_func']
+
+    ## Create sampler
+    lgr.info('- Create sampler:')
+    sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
+                                       balanced_class_name='data.gt',
+                                       num_balanced_classes=train_common_params['num_classes'],
+                                       balanced_class_probs=[1.0 / train_common_params['num_classes']] * train_common_params['num_classes'] if train_common_params['task_num'] == 2 else None,
+                                       batch_size=train_common_params['data.batch_size'],
+                                       use_dataset_cache=True)
+
+    lgr.info('- Create sampler: Done')
+
+    ## Create dataloader
+    train_dataloader = DataLoader(dataset=train_dataset,
+                                  shuffle=False, drop_last=False,
+                                  batch_sampler=sampler, collate_fn=train_dataset.collate_fn,
+                                  num_workers=train_common_params['data.train_num_workers'])
+    lgr.info('Train Data: Done', {'attrs': 'bold'})
+
+    #### Validation data
+    lgr.info('Validation Data:', {'attrs': 'bold'})
+
+    ## Create dataloader
+    validation_dataloader = DataLoader(dataset=validation_dataset,
+                                       shuffle=False,
+                                       drop_last=False,
+                                       batch_sampler=None,
+                                       batch_size=train_common_params['data.batch_size'],
+                                       num_workers=train_common_params['data.validation_num_workers'],
+                                       collate_fn=validation_dataset.collate_fn)
+    lgr.info('Validation Data: Done', {'attrs': 'bold'})
+
+    # ==============================================================================
+    # Model
+    # ==============================================================================
+    lgr.info('Model:', {'attrs': 'bold'})
+
+    model = train_common_params['model']
+
+    lgr.info('Model: Done', {'attrs': 'bold'})
+
+    # ====================================================================================
+    # Loss
+    # ====================================================================================
+    losses = train_common_params['loss']
+
+    # ====================================================================================
+    # Metrics
+    # ====================================================================================
+    metrics = train_common_params['metrics']
+
+    # =====================================================================================
+    # Callbacks
+    # =====================================================================================
+    callbacks = [
+        # default callbacks
+        FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
+        FuseMetricStatisticsCallback(output_path=paths['model_dir'] + "/metrics.csv"),  # save statistics in a csv file
+        FuseTimeStatisticsCallback(num_epochs=train_common_params['manager.train_params']['num_epochs'],
+                                   load_expected_part=0.1)  # time profiler
+    ]
+
+    # =====================================================================================
+    # Manager - Train
+    # Create a manager, training objects and run a training process.
+    # =====================================================================================
+    lgr.info('Train:', {'attrs': 'bold'})
+
+    # create optimizer
+    optimizer = train_common_params['optimizer']
+    # create scheduler
+    scheduler = train_common_params['scheduler']
+
+    # train from scratch
+    manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])
+    # Providing the objects required for the training process.
+    manager.set_objects(net=model,
+                        optimizer=optimizer,
+                        losses=losses,
+                        metrics=metrics,
+                        best_epoch_source=train_common_params['manager.best_epoch_source'],
+                        lr_scheduler=scheduler,
+                        callbacks=callbacks,
+                        train_params=train_common_params['manager.train_params'],
+                        output_model_dir=paths['model_dir'])
+
+    # Continue training
+    if train_common_params['manager.resume_checkpoint_filename'] is not None:
+        # Loading the checkpoint including model weights, learning rate, and epoch_index.
+        manager.load_checkpoint(checkpoint=train_common_params['manager.resume_checkpoint_filename'], mode='train',
+                                values_to_resume=['net'])
+    # Start training
+    manager.train(train_dataloader=train_dataloader,
+                  validation_dataloader=validation_dataloader)
+
+    lgr.info('Train: Done', {'attrs': 'bold'})
+
+
+######################################
+# Inference Template
+######################################
+def run_infer(paths: dict, train_common_params: dict, infer_common_params: dict):
+    #### Logger
+    fuse_logger_start(output_path=paths['inference_dir'], console_verbose_level=logging.INFO)
+    lgr = logging.getLogger('Fuse')
+    lgr.info('Fuse Inference', {'attrs': ['bold', 'underline']})
+    lgr.info(f'infer_filename={os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])}',
+             {'color': 'magenta'})
+
+    # Create data source:
+    train_dataloader, validation_dataloader, test_dataloader, \
+    train_dataset, validation_dataset, test_dataset = train_common_params['dataset_func']
+
+    ## Create dataloader
+    infer_dataloader = DataLoader(dataset=validation_dataset,
+                                  shuffle=False, drop_last=False,
+                                  collate_fn=validation_dataset.collate_fn,
+                                  num_workers=infer_common_params['data.train_num_workers'])
+    lgr.info('Test Data: Done', {'attrs': 'bold'})
+
+    #### Manager for inference
+    manager = FuseManagerDefault()
+    # extract just the global classification per sample and save to a file
+    output_columns = infer_common_params['output_keys']
+    manager.infer(data_loader=infer_dataloader,
+                  input_model_dir=infer_common_params['model_dir'],
+                  checkpoint=infer_common_params['checkpoint'],
+                  output_columns=output_columns,
+                  output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))
+
+
+######################################
+# Analyze Template
+######################################
+def run_analyze(paths: dict, analyze_common_params: dict):
+    fuse_logger_start(output_path=None, console_verbose_level=logging.INFO)
+    lgr = logging.getLogger('Fuse')
+    lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']})
+
+    # metrics
+    metrics = analyze_common_params['metrics']
+
+    # create analyzer
+    analyzer = EvaluatorDefault()
+
+    # run
+    # FIXME: simplify analyze interface for this case
+    results = analyzer.eval(ids=None,
+                            data=os.path.join(paths["inference_dir"],
+                                              analyze_common_params["infer_filename"]),
+                            metrics=metrics,
+                            output_dir=analyze_common_params['output_filename'])
+
+    return results
+
+
+######################################
+# Run
+######################################
+
+
+if __name__ == "__main__":
+    # allocate gpus
+    if NUM_GPUS == 0:
+        TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu'
+    # uncomment if you want to use specific gpus instead of automatically looking for free ones
+    force_gpus = None  # [0]
+    FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus)
+
+    RUNNING_MODES = ['train', 'infer', 'analyze']  # Options: 'train', 'infer', 'analyze'
+
+    paths = TRAIN_COMMON_PARAMS['paths']
+    # train
+    if 'train' in RUNNING_MODES:
+        run_train(paths=paths, train_common_params=TRAIN_COMMON_PARAMS, reset_cache=False)
+
+    # infer
+    if 'infer' in RUNNING_MODES:
+        run_infer(paths=paths, train_common_params=TRAIN_COMMON_PARAMS, infer_common_params=INFER_COMMON_PARAMS)
+
+    # analyze
+    if 'analyze' in RUNNING_MODES:
+        run_analyze(paths=paths, analyze_common_params=ANALYZE_COMMON_PARAMS)

From d08bf244bc3d84034c8be151a6a825c907a38122 Mon Sep 17 00:00:00 2001
From: ttlusty
Date: Wed, 1 Jun 2022 14:18:30 +0300
Subject: [PATCH 6/7] knight multimodal example

---
 .../multimodality/dataset_knight.py           | 273 ++++++++++++++++++
 .../multimodality/runner_knight.py            |   8 +-
 2 files changed, 277 insertions(+), 4 deletions(-)
 create mode 100644 fuse_examples/classification/multimodality/dataset_knight.py

diff --git a/fuse_examples/classification/multimodality/dataset_knight.py b/fuse_examples/classification/multimodality/dataset_knight.py
new file mode 100644
index 000000000..94a71815f
--- /dev/null
+++ b/fuse_examples/classification/multimodality/dataset_knight.py
@@ -0,0 +1,273 @@
+import os
+from functools import partial
+from fuse.data.visualizer.visualizer_default_3d import Fuse3DVisualizerDefault
+from fuse.data.augmentor.augmentor_default import FuseAugmentorDefault
+from fuse.data.augmentor.augmentor_toolbox import aug_op_affine, aug_op_gaussian
+from fuse.data.dataset.dataset_default import FuseDatasetDefault
+from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
+
+from fuse.utils.rand.param_sampler import Uniform, RandInt, RandBool
+from torch.utils.data.dataloader import DataLoader
+from fuse_examples.classification.knight.baseline.input_processor import KiTSBasicInputProcessor
+from fuse.data.data_source.data_source_default import FuseDataSourceDefault
+
+from fuse.data.augmentor.augmentor_toolbox import rotation_in_3d, squeeze_3d_to_2d, unsqueeze_2d_to_3d
+from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
+import torch
+from fuse_examples.classification.knight.baseline.clinical_processor import KiCClinicalProcessor, KiCGTProcessor
+
+
+def prepare_clinical(sample_dict: dict, target_name: str) -> dict:
+    """
+    Normalize the continuous clinical values, one-hot encode the categorical ones and
+    pack them into the keys the multimodality pipeline expects.
+    Missing or out-of-range values are encoded as -1.0.
+    """
+    age = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.age_at_nephrectomy')
+    if age is not None and 0 < age < 120:
+        age = torch.tensor(age / 120.0).reshape(-1)
+    else:
+        age = torch.tensor(-1.0).reshape(-1)
+
+    bmi = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.body_mass_index')
+    if bmi is not None and 10 < bmi < 100:
+        bmi = torch.tensor(bmi / 50.0).reshape(-1)
+    else:
+        bmi = torch.tensor(-1.0).reshape(-1)
+
+    radiographic_size = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.radiographic_size')
+    if radiographic_size is not None and 0 < radiographic_size < 50:
+        radiographic_size = torch.tensor(radiographic_size / 15.0).reshape(-1)
+    else:
+        radiographic_size = torch.tensor(-1.0).reshape(-1)
+
+    preop_egfr = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.last_preop_egfr')
+    if preop_egfr is not None and 0 < preop_egfr < 200:
+        preop_egfr = torch.tensor(preop_egfr / 90.0).reshape(-1)
+    else:
+        preop_egfr = torch.tensor(-1.0).reshape(-1)
+
+    # turn categorical features into one hot vectors
+    gender = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.gender')
+    gender_one_hot = torch.zeros(len(GENDER_INDEX))
+    if gender in GENDER_INDEX.values():
+        gender_one_hot[gender] = 1
+
+    comorbidities = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.comorbidities')
+    comorbidities_one_hot = torch.zeros(len(COMORBIDITIES_INDEX))
+    if comorbidities in COMORBIDITIES_INDEX.values():
+        comorbidities_one_hot[comorbidities] = 1
+
+    smoking_history = FuseUtilsHierarchicalDict.get(sample_dict, 'data.input.clinical.smoking_history')
+    smoking_history_one_hot = torch.zeros(len(SMOKE_HISTORY_INDEX))
+    if smoking_history in SMOKE_HISTORY_INDEX.values():
+        smoking_history_one_hot[smoking_history] = 1
+
+    clinical_encoding = torch.cat(
+        (age, bmi, radiographic_size, preop_egfr, gender_one_hot, comorbidities_one_hot, smoking_history_one_hot),
+        dim=0)
+    FuseUtilsHierarchicalDict.set(sample_dict, "data.input.clinical.all", clinical_encoding)
+    continuous_clinical_feat = torch.cat((age, bmi, radiographic_size, preop_egfr), dim=0)
+    categorical_clinical_feat = torch.cat((gender_one_hot, comorbidities_one_hot, smoking_history_one_hot), dim=0)
+    # change fields to fit the multimodality pipeline
+    FuseUtilsHierarchicalDict.set(sample_dict, "data.continuous", continuous_clinical_feat)
+    FuseUtilsHierarchicalDict.set(sample_dict, "data.categorical", categorical_clinical_feat)
+
+    # mirror the image and label under the keys the multimodality pipeline reads
+    img = FuseUtilsHierarchicalDict.get(sample_dict, "data.input.image")
+    FuseUtilsHierarchicalDict.set(sample_dict, "data.image", img)
+    gt = FuseUtilsHierarchicalDict.get(sample_dict, target_name)
+    FuseUtilsHierarchicalDict.set(sample_dict, "data.gt", gt)
+    return sample_dict
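+
+# Example of the resulting encoding under the normalization above: age_at_nephrectomy=60,
+# body_mass_index=25, radiographic_size=3.0 and last_preop_egfr=90 map to 0.5, 0.5, 0.2 and
+# 1.0 respectively; a gender index of 1 ('female') maps to the one-hot [0, 1], and any
+# missing or out-of-range value maps to -1.0.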
+
+
+def knight_dataset(data_dir: str = 'data', cache_dir: str = 'cache', split: dict = None, \
+                   reset_cache: bool = False, \
+                   rand_gen=None, batch_size=8, resize_to=(256, 256, 110), task_num=1, \
+                   target_name='data.gt.gt_global.task_1_label', num_classes=2, only_labels=False):
+    augmentation_pipeline = [
+        [
+            ("data.input.image",),
+            rotation_in_3d,
+            {
+                "z_rot": Uniform(-5.0, 5.0),
+                "y_rot": Uniform(-5.0, 5.0),
+                "x_rot": Uniform(-5.0, 5.0),
+            },
+            {"apply": RandBool(0.5)},
+        ],
+        [("data.input.image",), squeeze_3d_to_2d, {"axis_squeeze": "z"}, {}],
+        [
+            ("data.input.image",),
+            aug_op_affine,
+            {
+                "rotate": Uniform(0, 360.0),
+                "translate": (RandInt(-14, 14), RandInt(-14, 14)),
+                "flip": (RandBool(0.5), RandBool(0.5)),
+                "scale": Uniform(0.9, 1.1),
+            },
+            {"apply": RandBool(0.9)},
+        ],
+        [
+            ("data.input.image",),
+            aug_op_gaussian,
+            {"std": 0.01},
+            {"apply": RandBool(0.9)},
+        ],
+        [
+            ("data.input.image",),
+            unsqueeze_2d_to_3d,
+            {"channels": 1, "axis_squeeze": "z"},
+            {},
+        ],
+    ]
+
+    if 'train' in split:
+        train_data_source = FuseDataSourceDefault(list(split['train']))
+        image_dir = os.path.join(data_dir, 'knight', 'data')
+        json_filepath = os.path.join(image_dir, 'knight.json')
+        gt_processors = {
+            'gt_global': KiCGTProcessor(json_filename=json_filepath,
+                                        columns_to_tensor={'task_1_label': torch.long, 'task_2_label': torch.long})
+        }
+    else:  # split can contain BOTH 'train' and 'val', or JUST 'test'
+        image_dir = os.path.join(data_dir, 'images')
+        json_filepath = os.path.join(data_dir, 'features.json')
+        if only_labels:
+            json_labels_filepath = os.path.join(data_dir, 'knight_test_labels.json')
+            gt_processors = {
+                'gt_global': KiCGTProcessor(json_filename=json_labels_filepath,
+                                            columns_to_tensor={'task_1_label': torch.long, 'task_2_label': torch.long},
+                                            test_labels=True)
+            }
+        else:
+            gt_processors = {}
+
+    if only_labels:
+        # just labels - no need to load and process the input
+        input_processors = {}
+        post_processing_func = None
+    else:
+        # the same .json file holds both the clinical data and the ground truth, so the
+        # label column must be discarded from the data when using it as input
+        input_processors = {
+            'image': KiTSBasicInputProcessor(input_data=image_dir, resize_to=resize_to),
+            'clinical': KiCClinicalProcessor(json_filename=json_filepath)
+        }
+        post_processing_func = partial(prepare_clinical, target_name=target_name)
+
+    # Create data augmentation (optional)
+    augmentor = FuseAugmentorDefault(
+        augmentation_pipeline=augmentation_pipeline)
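+
+    # Note on the pipeline above: the volume is first rotated in 3D, then squeezed along z
+    # into a stack of 2D slices so the 2D affine/gaussian ops can run, and finally
+    # unsqueezed back into a single-channel 3D volume.
+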
+    # Create visualizer (optional)
+    visualizer = Fuse3DVisualizerDefault(image_name='data.input.image', label_name=target_name)
+    # Create dataset
+    if 'train' in split:
+        train_dataset = FuseDatasetDefault(cache_dest=cache_dir,
+                                           data_source=train_data_source,
+                                           input_processors=input_processors,
+                                           gt_processors=gt_processors,
+                                           post_processing_func=post_processing_func,
+                                           augmentor=augmentor,
+                                           visualizer=visualizer)
+
+        print('- Load and cache data:')
+        train_dataset.create(reset_cache=reset_cache)
+        print('- Load and cache data: Done')
+
+        ## Create sampler
+        print('- Create sampler:')
+        sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
+                                           balanced_class_name=target_name,
+                                           num_balanced_classes=num_classes,
+                                           batch_size=batch_size,
+                                           balanced_class_probs=[
+                                               1.0 / num_classes] * num_classes if task_num == 2 else None,
+                                           use_dataset_cache=False)  # we don't want use_dataset_cache here since it's
+        # more costly to read all the cached data than simply the file which contains the labels
+
+        print('- Create sampler: Done')
+
+        ## Create dataloader
+        train_dataloader = DataLoader(dataset=train_dataset,
+                                      shuffle=False, drop_last=False,
+                                      batch_sampler=sampler, collate_fn=train_dataset.collate_fn,
+                                      num_workers=8, generator=rand_gen)
+        print('Train Data: Done')
+
+        #### Validation data
+        print('Validation Data:')
+
+        ## Create data source
+        validation_data_source = FuseDataSourceDefault(list(split['val']))
+
+        ## Create dataset
+        validation_dataset = FuseDatasetDefault(cache_dest=cache_dir,
+                                                data_source=validation_data_source,
+                                                input_processors=input_processors,
+                                                gt_processors=gt_processors,
+                                                post_processing_func=post_processing_func,
+                                                visualizer=visualizer)
+
+        print('- Load and cache data:')
+        validation_dataset.create(
+            pool_type='thread')  # use ThreadPool to create this dataset, to avoid cv2 problems in multiprocessing
+        print('- Load and cache data: Done')
+
+        ## Create dataloader
+        validation_dataloader = DataLoader(dataset=validation_dataset,
+                                           shuffle=False,
+                                           drop_last=False,
+                                           batch_sampler=None,
+                                           batch_size=batch_size,
+                                           num_workers=8,
+                                           collate_fn=validation_dataset.collate_fn,
+                                           generator=rand_gen)
+        print('Validation Data: Done')
+        test_dataloader = test_dataset = None
+    else:  # test only
+        #### Test data
+        print('Test Data:')
+
+        ## Create data source
+        test_data_source = FuseDataSourceDefault(list(split['test']))
+
+        ## Create dataset
+        test_dataset = FuseDatasetDefault(cache_dest=cache_dir,
+                                          data_source=test_data_source,
+                                          input_processors=input_processors,
+                                          gt_processors=gt_processors,
+                                          post_processing_func=post_processing_func,
+                                          visualizer=visualizer)
+
+        print('- Load and cache data:')
+        test_dataset.create(
+            pool_type='thread')  # use ThreadPool to create this dataset, to avoid cv2 problems in multiprocessing
+        print('- Load and cache data: Done')
+
+        ## Create dataloader
+        test_dataloader = DataLoader(dataset=test_dataset,
+                                     shuffle=False,
+                                     drop_last=False,
+                                     batch_sampler=None,
+                                     batch_size=batch_size,
+                                     num_workers=8,
+                                     collate_fn=test_dataset.collate_fn,
+                                     generator=rand_gen)
+        print('Test Data: Done')
+        train_dataloader = train_dataset = validation_dataloader = validation_dataset = None
+    return train_dataloader, validation_dataloader, test_dataloader, \
+           train_dataset, validation_dataset, test_dataset
+
+
+GENDER_INDEX = {
+    'male': 0,
+    'female': 1
+}
+COMORBIDITIES_INDEX = {
+    'no comorbidities': 0,
+    'comorbidities exist': 1
+}
+SMOKE_HISTORY_INDEX = {
+    'never smoked': 0,
+    'previous smoker': 1,
+    'current smoker': 2
+}
diff --git a/fuse_examples/classification/multimodality/runner_knight.py b/fuse_examples/classification/multimodality/runner_knight.py
index 8c652d89a..de8bb725b 100644
--- a/fuse_examples/classification/multimodality/runner_knight.py
+++ b/fuse_examples/classification/multimodality/runner_knight.py
@@ -33,10 +33,10 @@ from fuse.eval.evaluator import EvaluatorDefault
 from fuse_examples.classification.multimodality.multimodel_parameters import multimodal_parameters
-from fuse_examples.classification.knight.baseline.dataset import knight_dataset
-
 from fuse_examples.classification.multimodality.multimodal_paths import multimodal_paths
 from fuse_examples.classification.multimodality.model_tabular_imaging import project_imaging, project_tabular
+from fuse_examples.classification.multimodality.dataset_knight import knight_dataset
+
 
 def tabular_feature_knight():
     features_dict = {}
@@ -74,13 +74,13 @@ def tabular_feature_knight():
 root_data = '/projects/msieve_dev3/usr/Tal/my_research/multi-modality/'  # TODO: add path to the data folder
 assert root_data is not None, "Error: please set root_data, the path to the stored MM dataset location"
 # Name of the experiment
-experiment = 'interactive'
+experiment = 'mono_tabular_'
 # Path to cache data
 cache_path = root_data+'/knight/cache_knight_256_256_64/'
 
 paths = multimodal_paths(dataset_name, root_data, root, experiment, cache_path)
 TRAIN_COMMON_PARAMS['paths'] = paths
-TRAIN_COMMON_PARAMS['fusion_type'] = 'interactive'
+TRAIN_COMMON_PARAMS['fusion_type'] = 'mono_tabular'
 ######################################
 # Inference Common Params
 ######################################

From de1d355b81bec64395f2ef67880944eb0dbfe456 Mon Sep 17 00:00:00 2001
From: ttlusty
Date: Wed, 1 Jun 2022 14:21:02 +0300
Subject: [PATCH 7/7] knight multimodal example

---
 .../multimodality/multimodel_parameters.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/fuse_examples/classification/multimodality/multimodel_parameters.py b/fuse_examples/classification/multimodality/multimodel_parameters.py
index 26c066183..4ef7b127c 100644
--- a/fuse_examples/classification/multimodality/multimodel_parameters.py
+++ b/fuse_examples/classification/multimodality/multimodel_parameters.py
@@ -30,10 +30,9 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an
     model_projection_tabular = train_common_params['tabular_projector']
 
     model_interactive_3d = FuseBackboneResnet3DInteractive(
-            conv_inputs=(('data.image', 1),),
-            fcn_inputs=(('data.input.clinical.all', 1),),
-    )
-
+        conv_inputs=(('data.image', 1),),
+        fcn_inputs=(('data.input.clinical.all', 1),),
+    )
 
     heads_for_multimodal = {
         'multimodal_head':
@@ -257,10 +256,7 @@ def multimodal_parameters(train_common_params: dict,infer_common_params: dict,an
 
     if train_common_params['fusion_type'] == 'interactive':
 
-        train_common_params['model'] = FuseModelDefaultInteractive(backbone=FuseBackboneResnet3DInteractive(
-            conv_inputs=(('data.image', 1),),
-            fcn_inputs=(('data.input.clinical.all', 1),),
-        ),
+        train_common_params['model'] = FuseModelDefaultInteractive(backbone=model_interactive_3d,
                                                                    heads=heads_for_multimodal['interactive_head_3d'],
                                                                    )