
Conversation

atong01 (Owner) commented Jul 22, 2025

Summary by Sourcery

Introduce a high-dimensional single-cell flow-matching example and utility module, and update README with pepy download badges

New Features:

  • Add train_single_cell_high_dimension.py example script for training conditional flow matching on high-dimensional single-cell data
  • Add utils_single_cell.py module with functions to load AnnData datasets, split timepoint data, and build combined train/val/test dataloaders

Documentation:

  • Add pepy download badges to README

sourcery-ai bot commented Jul 22, 2025

Reviewer's Guide

This PR adds a high-dimensional single-cell training example with supporting utilities and updates the README to include download badges for the torchcfm package.

Class diagram for new single-cell training utilities

classDiagram
    class adata_dataset {
        +adata_dataset(path, embed_name="X_pca", label_name="sample_labels", max_dim=100)
    }
    class split {
        +split(timepoint_data)
    }
    class combined_loader {
        +combined_loader(split_timepoint_data, index, shuffle=False, load_full=False)
    }
    class train_dataloader {
        +train_dataloader(split_timepoint_data)
    }
    class val_dataloader {
        +val_dataloader(split_timepoint_data)
    }
    class test_dataloader {
        +test_dataloader(split_timepoint_data)
    }
    adata_dataset --> split
    split --> combined_loader
    combined_loader --> train_dataloader
    combined_loader --> val_dataloader
    combined_loader --> test_dataloader

File-Level Changes

1. Update README with download badges
     • Insert total downloads badge
     • Insert monthly downloads badge
   Files: README.md
2. Introduce single-cell high-dimensional training script
     • Import required libraries and modules
     • Load and preprocess single-cell data
     • Define and execute CFM training loop
   Files: examples/single_cell/train_single_cell_high_dimension.py
3. Add utility functions for single-cell data handling
     • Implement adata dataset loading function
     • Implement split, train/val/test split logic
     • Define combined DataLoader for timepoint batches
   Files: examples/single_cell/utils_single_cell.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

sourcery-ai bot left a comment


Hey @atong01 - I've reviewed your changes - here's some feedback:

  • The example script hardcodes dataset paths, hyperparameters, and training settings; consider adding an argument parser or config file to make these parameters configurable.
  • The data loading pipeline mixes 'adata' and the scaled 'data' variables inconsistently (e.g., timepoint_data uses raw adata instead of the scaled data), so unify preprocessing into a single data object to avoid mismatches.
  • It would be useful to add basic training logging (e.g., printing or plotting loss via tqdm) and model checkpointing to improve observability and reproducibility.
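The first point above could be addressed with a small argument parser. A minimal sketch (flag names are illustrative; the defaults mirror the values currently hardcoded in the PR):

```python
import argparse

def parse_args(argv=None):
    # Hypothetical flags for the hardcoded settings the review mentions.
    p = argparse.ArgumentParser(description="Train CFM on single-cell data")
    p.add_argument("--data-path", default="./ebdata_v2.h5ad")
    p.add_argument("--max-dim", type=int, default=100)
    p.add_argument("--n-epochs", type=int, default=2000)
    p.add_argument("--batch-size", type=int, default=128)
    return p.parse_args(argv)

# Any default can then be overridden from the command line:
args = parse_args(["--max-dim", "1000"])
```

A config file (e.g. YAML plus a small loader) would serve the same purpose if the script grows more settings.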
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The example script hardcodes dataset paths, hyperparameters, and training settings; consider adding an argument parser or config file to make these parameters configurable.
- The data loading pipeline mixes 'adata' and the scaled 'data' variables inconsistently (e.g., timepoint_data uses raw adata instead of the scaled data), so unify preprocessing into a single data object to avoid mismatches.
- It would be useful to add basic training logging (e.g., printing or plotting loss via tqdm) and model checkpointing to improve observability and reproducibility.

## Individual Comments

### Comment 1
<location> `examples/single_cell/train_single_cell_high_dimension.py:23` </location>
<code_context>
+
+data, labels, ulabels = adata_dataset("./ebdata_v2.h5ad")
+
+if max_dim==1000:
+    sc.pp.highly_variable_genes(adata, n_top_genes=max_dim)
+    adata = adata.X[:, adata.var["highly_variable"]].toarray()
+
+# Standardize coordinates
</code_context>

<issue_to_address>
Potential confusion between 'adata' as AnnData and as a numpy array.

'adata' is reassigned from an AnnData object to a numpy array, which could lead to confusion or errors. Use a different variable name for the numpy array to maintain clarity.
</issue_to_address>
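A sketch of the suggested rename, using a tiny stub in place of the real AnnData object (which comes from anndata/scanpy; the stub and the new variable name are illustrative only):

```python
import numpy as np

class AnnDataStub:
    # Minimal stand-in for anndata.AnnData, for illustration only.
    def __init__(self, X, highly_variable):
        self.X = X
        self.var = {"highly_variable": highly_variable}

adata = AnnDataStub(np.arange(24.0).reshape(6, 4),
                    np.array([True, False, True, True]))

# Keep `adata` bound to the AnnData object; put the extracted
# matrix under a distinct name instead of rebinding `adata`.
expression = adata.X[:, adata.var["highly_variable"]]
```

With the rename, later code can still access `adata.obs`, `adata.var`, etc., while numeric operations use `expression` unambiguously.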

### Comment 2
<location> `examples/single_cell/train_single_cell_high_dimension.py:59` </location>
<code_context>
+n_epochs = 2000
+for _ in range(n_epochs):
+    for X in train_dataloader:
+        t_snapshot = np.random.randint(0,4)
+        t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot+1])
+        ot_cfm_optimizer.zero_grad()
+        vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1))
</code_context>

<issue_to_address>
Hardcoded timepoint range may not generalize.

Using a fixed range with 'np.random.randint(0,4)' risks IndexError if the number of timepoints is less than 5. Please calculate the valid range dynamically from the data.
</issue_to_address>
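A sketch of the dynamic range, assuming `X` is the list of per-timepoint batches yielded by the combined loader (the dummy shapes below are illustrative):

```python
import numpy as np

X = [np.random.rand(32, 10) for _ in range(3)]  # e.g. only 3 timepoints

n_timepoints = len(X)
# Sample a consecutive pair (t_snapshot, t_snapshot + 1) that stays in range,
# regardless of how many timepoints the dataset actually has.
t_snapshot = np.random.randint(0, n_timepoints - 1)
x0, x1 = X[t_snapshot], X[t_snapshot + 1]
```

`np.random.randint(low, high)` excludes `high`, so `t_snapshot + 1` is always a valid index.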

### Comment 3
<location> `examples/single_cell/utils_single_cell.py:16` </location>
<code_context>
+    ulabels = labels.cat.categories
+    return adata.obsm[embed_name][:, :max_dim], labels, ulabels
+
+def split(timepoint_data):
+    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
+    train_val_test_split = [0.8, 0.1, 0.1]
+    if isinstance(train_val_test_split, int):
+        split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
+        return split_timepoint_data
</code_context>

<issue_to_address>
Unreachable code path for integer split ratios.

Since 'train_val_test_split' is always a list, this type check is redundant and can be removed to simplify the code.
</issue_to_address>

<suggested_fix>
<<<<<<< SEARCH
def split(timepoint_data):
    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
    train_val_test_split = [0.8, 0.1, 0.1]
    if isinstance(train_val_test_split, int):
        split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
        return split_timepoint_data
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))
    return split_timepoint_data
=======
def split(timepoint_data):
    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
    train_val_test_split = [0.8, 0.1, 0.1]
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))
    return split_timepoint_data
>>>>>>> REPLACE

</suggested_fix>

### Comment 4
<location> `examples/single_cell/utils_single_cell.py:30` </location>
<code_context>
+    split_timepoint_data = list(map(splitter, timepoint_data))
+    return split_timepoint_data
+
+def combined_loader(split_timepoint_data, index, shuffle=False, load_full=False):
+    tp_dataloaders = [
+        DataLoader(
+            dataset=datasets[index],
+            batch_size=128,
+            shuffle=shuffle,
+            drop_last=True,
+        )
+        for datasets in split_timepoint_data
+    ]
+    return CombinedLoader(tp_dataloaders, mode="min_size")
+
+def train_dataloader(split_timepoint_data):
</code_context>

<issue_to_address>
Unused 'load_full' parameter in 'combined_loader'.

Consider removing the unused 'load_full' parameter if it serves no purpose.
</issue_to_address>


Comment on lines +20 to +27

    split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
    return split_timepoint_data
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))

issue (code-quality): Inline variable that is immediately returned [×2] (inline-immediately-returned-variable). In both branches, `split_timepoint_data` is assigned only to be returned on the next line; returning the expression directly is simpler.
