Single cell hd #169
base: main
Conversation
Reviewer's Guide

This PR adds a high-dimensional single-cell training example with supporting utilities and updates the README to include download badges for the torchcfm package.

Class diagram for new single-cell training utilities:

```mermaid
classDiagram
    class adata_dataset {
        +adata_dataset(path, embed_name="X_pca", label_name="sample_labels", max_dim=100)
    }
    class split {
        +split(timepoint_data)
    }
    class combined_loader {
        +combined_loader(split_timepoint_data, index, shuffle=False, load_full=False)
    }
    class train_dataloader {
        +train_dataloader(split_timepoint_data)
    }
    class val_dataloader {
        +val_dataloader(split_timepoint_data)
    }
    class test_dataloader {
        +test_dataloader(split_timepoint_data)
    }
    adata_dataset --> split
    split --> combined_loader
    combined_loader --> train_dataloader
    combined_loader --> val_dataloader
    combined_loader --> test_dataloader
```
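A hedged sketch of how these utilities chain together in the example; the per-timepoint grouping step is an assumption inferred from the names above, not copied from the PR:

```python
import torch

from utils_single_cell import adata_dataset, split, train_dataloader

data, labels, ulabels = adata_dataset("./ebdata_v2.h5ad")

# Assumed construction: one tensor of cells per timepoint label. The PR may
# build `timepoint_data` differently.
timepoint_data = [
    torch.tensor(data[(labels == lab).values], dtype=torch.float32)
    for lab in ulabels
]

split_data = split(timepoint_data)           # per-timepoint train/val/test
train_loader = train_dataloader(split_data)  # CombinedLoader over timepoints
```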
Hey @atong01 - I've reviewed your changes - here's some feedback:
- The example script hardcodes dataset paths, hyperparameters, and training settings; consider adding an argument parser or config file to make these parameters configurable (see the sketch after this list).
- The data loading pipeline mixes 'adata' and the scaled 'data' variables inconsistently (e.g., timepoint_data uses raw adata instead of the scaled data), so unify preprocessing into a single data object to avoid mismatches.
- It would be useful to add basic training logging (e.g., printing or plotting loss via tqdm) and model checkpointing to improve observability and reproducibility.
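For the first point, a minimal `argparse` sketch of what a configurable entry point could look like. Flag names and defaults here are illustrative; only the data path, `max_dim`, `n_epochs = 2000`, and `batch_size = 128` come from the PR:

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description="High-dimensional single-cell flow matching example"
    )
    parser.add_argument("--data-path", default="./ebdata_v2.h5ad",
                        help="Path to the AnnData (.h5ad) file")
    parser.add_argument("--max-dim", type=int, default=100,
                        help="Number of embedding dimensions to keep")
    parser.add_argument("--n-epochs", type=int, default=2000,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Per-timepoint batch size")
    return parser.parse_args()

# Hypothetical wiring into the example's existing loader:
args = parse_args()
data, labels, ulabels = adata_dataset(args.data_path, max_dim=args.max_dim)
```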
## Individual Comments
### Comment 1
Location: `examples/single_cell/train_single_cell_high_dimension.py:23`

Code context:

```diff
+data, labels, ulabels = adata_dataset("./ebdata_v2.h5ad")
+
+if max_dim==1000:
+    sc.pp.highly_variable_genes(adata, n_top_genes=max_dim)
+    adata = adata.X[:, adata.var["highly_variable"]].toarray()
+
+# Standardize coordinates
```

Issue (suggestion): potential confusion between `adata` as AnnData and as a NumPy array. `adata` is reassigned from an AnnData object to a NumPy array, which could lead to confusion or errors. Use a different variable name for the NumPy array to maintain clarity.
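A minimal sketch of the suggested rename; the name `expr` is illustrative, and `.values` is used because indexing a sparse matrix with a pandas Series can be unreliable:

```python
if max_dim == 1000:
    sc.pp.highly_variable_genes(adata, n_top_genes=max_dim)
    # Keep `adata` as the AnnData object; give the dense matrix its own name.
    expr = adata.X[:, adata.var["highly_variable"].values].toarray()
```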
### Comment 2
Location: `examples/single_cell/train_single_cell_high_dimension.py:59`

Code context:

```diff
+n_epochs = 2000
+for _ in range(n_epochs):
+    for X in train_dataloader:
+        t_snapshot = np.random.randint(0,4)
+        t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot+1])
+        ot_cfm_optimizer.zero_grad()
+        vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1))
```

Issue (bug_risk): hardcoded timepoint range may not generalize. Using a fixed range with `np.random.randint(0,4)` risks an IndexError if the number of timepoints is less than 5. Please calculate the valid range dynamically from the data.
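A sketch of the dynamic version, assuming the combined loader yields one batch per timepoint so that `len(X)` equals the number of timepoints:

```python
# Sample a random adjacent pair (t, t + 1); deriving the upper bound from
# len(X) keeps the indices valid for any number of timepoints >= 2.
t_snapshot = np.random.randint(0, len(X) - 1)
t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot + 1])
```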
### Comment 3
Location: `examples/single_cell/utils_single_cell.py:16`

Code context:

```diff
+    ulabels = labels.cat.categories
+    return adata.obsm[embed_name][:, :max_dim], labels, ulabels
+
+def split(timepoint_data):
+    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
+    train_val_test_split = [0.8, 0.1, 0.1]
+    if isinstance(train_val_test_split, int):
+        split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
+        return split_timepoint_data
```

Issue (suggestion): unreachable code path for integer split ratios. Since `train_val_test_split` is always a list, this type check is redundant and can be removed to simplify the code.

Suggested fix (drops the unreachable branch):

```python
def split(timepoint_data):
    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
    train_val_test_split = [0.8, 0.1, 0.1]
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))
    return split_timepoint_data
```
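One caveat worth noting (an addition, not raised by the reviewer): fractional `lengths` such as `[0.8, 0.1, 0.1]` are only accepted by `torch.utils.data.random_split` from PyTorch 1.13 onward; older versions require absolute element counts.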
### Comment 4
Location: `examples/single_cell/utils_single_cell.py:30`

Code context:

```diff
+    split_timepoint_data = list(map(splitter, timepoint_data))
+    return split_timepoint_data
+
+def combined_loader(split_timepoint_data, index, shuffle=False, load_full=False):
+    tp_dataloaders = [
+        DataLoader(
+            dataset=datasets[index],
+            batch_size=128,
+            shuffle=shuffle,
+            drop_last=True,
+        )
+        for datasets in split_timepoint_data
+    ]
+    return CombinedLoader(tp_dataloaders, mode="min_size")
+
+def train_dataloader(split_timepoint_data):
```

Issue (nitpick): unused `load_full` parameter in `combined_loader`. Consider removing it if it serves no purpose.
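A sketch of the trimmed signature, assuming no other call site in the example passes `load_full`:

```python
def combined_loader(split_timepoint_data, index, shuffle=False):
    # One DataLoader per timepoint, each drawing from the requested split
    # (0 = train, 1 = val, 2 = test, following the order random_split returns).
    tp_dataloaders = [
        DataLoader(
            dataset=datasets[index],
            batch_size=128,
            shuffle=shuffle,
            drop_last=True,
        )
        for datasets in split_timepoint_data
    ]
    # "min_size" mode ends each epoch when the smallest loader is exhausted.
    return CombinedLoader(tp_dataloaders, mode="min_size")
```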
### Comment 5

Location: `examples/single_cell/utils_single_cell.py` (inline, in `split`)

Code context:

```python
    split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
    return split_timepoint_data
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))
```

Issue (code-quality): inline variable that is immediately returned [×2] (`inline-immediately-returned-variable`).
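A sketch of the inlined version, also folding in the Comment 3 fix that removes the unreachable branch:

```python
def split(timepoint_data):
    """Split each timepoint's dataset into train/val/test subsets."""
    splitter = partial(
        random_split,
        lengths=[0.8, 0.1, 0.1],
        generator=torch.Generator().manual_seed(42),
    )
    # Return directly instead of binding to a temporary first.
    return list(map(splitter, timepoint_data))
```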
Summary by Sourcery
Introduce a high-dimensional single-cell flow-matching example and utility module, and update the README with pepy download badges.

New Features:
- High-dimensional single-cell training example (`examples/single_cell/train_single_cell_high_dimension.py`) with supporting data-loading utilities (`examples/single_cell/utils_single_cell.py`).

Documentation:
- Add pepy download badges for the torchcfm package to the README.