From 97a97f034a0eb755c56facbd69b4df359562a3f9 Mon Sep 17 00:00:00 2001 From: bronval Date: Tue, 23 Jan 2024 14:15:14 +0100 Subject: [PATCH] fix shape mismatch in function inverse_transform_continuous --- ctgan/data_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index c1e136b5..b78b2aca 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -189,7 +189,9 @@ def transform(self, raw_data): def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st): gm = column_transform_info.transform - data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())) + cols_names = list(gm.get_output_sdtypes()) + n_cols = len(cols_names) + data = pd.DataFrame(column_data[:, :n_cols], columns=cols_names) data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1) if sigmas is not None: selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])