Skip to content

Commit a798893

Browse files
committed
Merged Main back into geomloss modifications that allow slicing along columns.
Merge remote-tracking branch 'origin/master' into geomloss_update
2 parents 5ee1a4d + 9043960 commit a798893

File tree

5 files changed

+43
-8
lines changed

5 files changed

+43
-8
lines changed

.github/workflows/build_doc.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@ jobs:
1515
steps:
1616
- uses: actions/checkout@v4
1717
# Standard drop-in approach that should work for most people.
18-
18+
- name: Free Disk Space (Ubuntu)
19+
uses: insightsengineering/disk-space-reclaimer@v1
20+
with:
21+
android: true
22+
dotnet: true
1923
- name: Set up Python 3.10
2024
uses: actions/setup-python@v5
2125
with:
2226
python-version: "3.10"
27+
cache: 'pip'
2328

2429
- name: Get Python running
2530
run: |

.github/workflows/build_tests.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,17 @@ jobs:
6464
python-version: ["3.10", "3.11", "3.12", "3.13"]
6565

6666
steps:
67+
- name: Free Disk Space (Ubuntu)
68+
uses: insightsengineering/disk-space-reclaimer@v1
69+
with:
70+
android: true
71+
dotnet: true
6772
- uses: actions/checkout@v4
6873
- name: Set up Python ${{ matrix.python-version }}
6974
uses: actions/setup-python@v5
7075
with:
7176
python-version: ${{ matrix.python-version }}
77+
cache: 'pip'
7278
- name: Install POT
7379
run: |
7480
pip install -e .
@@ -93,6 +99,7 @@ jobs:
9399
uses: actions/setup-python@v5
94100
with:
95101
python-version: "3.13"
102+
cache: 'pip'
96103
- name: Install dependencies
97104
run: |
98105
python -m pip install --upgrade pip setuptools
@@ -121,6 +128,7 @@ jobs:
121128
uses: actions/setup-python@v5
122129
with:
123130
python-version: ${{ matrix.python-version }}
131+
cache: 'pip'
124132
- name: Install POT
125133
run: |
126134
pip install -e .
@@ -148,6 +156,7 @@ jobs:
148156
uses: actions/setup-python@v5
149157
with:
150158
python-version: ${{ matrix.python-version }}
159+
cache: 'pip'
151160
- name: RC.exe
152161
run: |
153162
function Invoke-VSDevEnvironment {

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
77
#### New features
88
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
99
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)
10+
- Geomloss function now handles both scalar and slice indices for i and j. Using backend agnostic reshaping. Allows to do plan[i,:] and plan[:,j]
1011

1112
#### Closed issues
1213
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
1314
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
1415
- Add test for build from source (PR #772, Issue #764)
16+
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)
1517

1618
## 0.9.6.post1
1719

ot/batch/_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,9 @@ def solve_batch(
310310
B, n, m = M.shape
311311

312312
if a is None:
313-
a = nx.ones((B, n)) / n
313+
a = nx.ones((B, n), type_as=M) / n
314314
if b is None:
315-
b = nx.ones((B, m)) / m
315+
b = nx.ones((B, m), type_as=M) / m
316316

317317
if solver == "log_sinkhorn":
318318
K = -M / reg

ot/bregman/_geomloss.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,32 @@ def get_sinkhorn_geomloss_lazytensor(
5454
shape = (X_a.shape[0], X_b.shape[0])
5555

5656
def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
57+
X_a_i = X_a[i]
58+
X_b_j = X_b[j]
59+
60+
if X_a_i.ndim == 1:
61+
X_a_i = X_a_i[None, :]
62+
if X_b_j.ndim == 1:
63+
X_b_j = X_b_j[None, :]
64+
5765
if metric == "sqeuclidean":
58-
C = dist(X_a[i], X_b[j], metric=metric) / 2
66+
C = dist(X_a_i, X_b_j, metric=metric) / 2
5967
else:
60-
C = dist(X_a[i], X_b[j], metric=metric)
61-
return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (
62-
a[i, None] * b[None, j]
63-
)
68+
C = dist(X_a_i, X_b_j, metric=metric)
69+
70+
# Robust broadcasting using nx backend (handles both numpy and torch)
71+
# For scalars, slice to keep 1D; for arrays, index directly
72+
f_i = f[i : i + 1] if isinstance(i, int) else f[i]
73+
g_j = g[j : j + 1] if isinstance(j, int) else g[j]
74+
a_i = a[i : i + 1] if isinstance(i, int) else a[i]
75+
b_j = b[j : j + 1] if isinstance(j, int) else b[j]
76+
77+
f_i = nx.reshape(f_i, (-1, 1))
78+
g_j = nx.reshape(g_j, (1, -1))
79+
a_i = nx.reshape(a_i, (-1, 1))
80+
b_j = nx.reshape(b_j, (1, -1))
81+
82+
return nx.squeeze(nx.exp((f_i + g_j - C) / (blur**2)) * a_i * b_j)
6483

6584
T = LazyTensor(
6685
shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur

0 commit comments

Comments
 (0)