Skip to content

Commit 1c479f2

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (2862).
1 parent f815a23 commit 1c479f2

File tree

280 files changed

+733273
-732997
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

280 files changed

+733273
-732997
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Translation Invariant Sinkhorn for Unbalanced Optimal Transport\n\nThis examples illustrates the better convergence of the translation\ninvariance Sinkhorn algorithm proposed in [73] compared to the classical\nSinkhorn algorithm.\n\n[73] S\u00e9journ\u00e9, T., Vialard, F. X., & Peyr\u00e9, G. (2022).\nFaster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.\nIn International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Cl\u00e9ment Bonet <clement.bonet@ensae.fr>\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Setting parameters\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"n_iter = 50 # nb iters\nn = 40 # nb samples\n\nnum_iter_max = 100\nn_noise = 10\n\nreg = 0.005\nreg_m_kl = 0.05\n\nmu_s = np.array([-1, -1])\ncov_s = np.array([[1, 0], [0, 1]])\n\nmu_t = np.array([4, 4])\ncov_t = np.array([[1, -.8], [-.8, 1]])"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"metadata": {},
42+
"source": [
43+
"## Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn\n\n"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"collapsed": false
51+
},
52+
"outputs": [],
53+
"source": [
54+
"err_sinkhorn_uot = np.empty((n_iter, num_iter_max))\nerr_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))\n\n\nfor seed in range(n_iter):\n np.random.seed(seed)\n xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\n xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n\n xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)\n xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)\n\n n = n + n_noise\n\n a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples\n\n # loss matrix\n M = ot.dist(xs, xt)\n M /= M.max()\n\n entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type=\"kl\", log=True, numItermax=num_iter_max, stopThr=0)\n entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type=\"kl\",\n method=\"sinkhorn_translation_invariant\", log=True,\n numItermax=num_iter_max, stopThr=0)\n\n err_sinkhorn_uot[seed] = log_uot[\"err\"]\n err_sinkhorn_uot_ti[seed] = log_uot_ti[\"err\"]"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Plot the results\n\n"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {
68+
"collapsed": false
69+
},
70+
"outputs": [],
71+
"source": [
72+
"mean_sinkh = np.mean(err_sinkhorn_uot, axis=0)\nstd_sinkh = np.std(err_sinkhorn_uot, axis=0)\n\nmean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0)\nstd_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0)\n\nabsc = list(range(num_iter_max))\n\npl.plot(absc, mean_sinkh, label=\"Sinkhorn\")\npl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5)\n\npl.plot(absc, mean_sinkh_ti, label=\"Translation Invariant Sinkhorn\")\npl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5)\n\npl.yscale(\"log\")\npl.legend()\npl.xlabel(\"Number of Iterations\")\npl.ylabel(r\"$\\|u-v\\|_\\infty$\")\npl.grid(True)\npl.show()"
73+
]
74+
}
75+
],
76+
"metadata": {
77+
"kernelspec": {
78+
"display_name": "Python 3",
79+
"language": "python",
80+
"name": "python3"
81+
},
82+
"language_info": {
83+
"codemirror_mode": {
84+
"name": "ipython",
85+
"version": 3
86+
},
87+
"file_extension": ".py",
88+
"mimetype": "text/x-python",
89+
"name": "python",
90+
"nbconvert_exporter": "python",
91+
"pygments_lexer": "ipython3",
92+
"version": "3.10.15"
93+
}
94+
},
95+
"nbformat": 4,
96+
"nbformat_minor": 0
97+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===============================================================
4+
Translation Invariant Sinkhorn for Unbalanced Optimal Transport
5+
===============================================================
6+
7+
This examples illustrates the better convergence of the translation
8+
invariance Sinkhorn algorithm proposed in [73] compared to the classical
9+
Sinkhorn algorithm.
10+
11+
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
12+
Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
13+
In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
14+
15+
"""
16+
17+
# Author: Clément Bonet <clement.bonet@ensae.fr>
18+
# License: MIT License
19+
20+
import numpy as np
21+
import matplotlib.pylab as pl
22+
import ot
23+
24+
##############################################################################
25+
# Setting parameters
26+
# -------------
27+
28+
# %% parameters
29+
30+
n_iter = 50 # nb iters
31+
n = 40 # nb samples
32+
33+
num_iter_max = 100
34+
n_noise = 10
35+
36+
reg = 0.005
37+
reg_m_kl = 0.05
38+
39+
mu_s = np.array([-1, -1])
40+
cov_s = np.array([[1, 0], [0, 1]])
41+
42+
mu_t = np.array([4, 4])
43+
cov_t = np.array([[1, -.8], [-.8, 1]])
44+
45+
46+
##############################################################################
47+
# Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn
48+
# -----------
49+
50+
err_sinkhorn_uot = np.empty((n_iter, num_iter_max))
51+
err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))
52+
53+
54+
for seed in range(n_iter):
55+
np.random.seed(seed)
56+
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
57+
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
58+
59+
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
60+
xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
61+
62+
n = n + n_noise
63+
64+
a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
65+
66+
# loss matrix
67+
M = ot.dist(xs, xt)
68+
M /= M.max()
69+
70+
entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0)
71+
entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl",
72+
method="sinkhorn_translation_invariant", log=True,
73+
numItermax=num_iter_max, stopThr=0)
74+
75+
err_sinkhorn_uot[seed] = log_uot["err"]
76+
err_sinkhorn_uot_ti[seed] = log_uot_ti["err"]
77+
78+
##############################################################################
79+
# Plot the results
80+
# ----------------
81+
82+
mean_sinkh = np.mean(err_sinkhorn_uot, axis=0)
83+
std_sinkh = np.std(err_sinkhorn_uot, axis=0)
84+
85+
mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0)
86+
std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0)
87+
88+
absc = list(range(num_iter_max))
89+
90+
pl.plot(absc, mean_sinkh, label="Sinkhorn")
91+
pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5)
92+
93+
pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn")
94+
pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5)
95+
96+
pl.yscale("log")
97+
pl.legend()
98+
pl.xlabel("Number of Iterations")
99+
pl.ylabel(r"$\|u-v\|_\infty$")
100+
pl.grid(True)
101+
pl.show()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
766 Bytes
246 Bytes
590 Bytes
218 Bytes
222 Bytes
-349 Bytes
-320 Bytes
-113 Bytes
-205 Bytes
101 Bytes
-151 Bytes

master/_modules/ot/backend.html

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,9 +2267,7 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
22672267
<span class="k">if</span> <span class="n">a</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
22682268
<span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">side</span><span class="p">)</span>
22692269
<span class="k">else</span><span class="p">:</span>
2270-
<span class="c1"># this is a not very efficient way to make jax numpy</span>
2271-
<span class="c1"># searchsorted work on 2d arrays</span>
2272-
<span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">jnp</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:],</span> <span class="n">v</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:],</span> <span class="n">side</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])])</span></div>
2270+
<span class="k">return</span> <span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">b</span><span class="p">,</span> <span class="n">u</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">u</span><span class="p">,</span> <span class="n">side</span><span class="p">))(</span><span class="n">a</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></div>
22732271

22742272

22752273
<div class="viewcode-block" id="JaxBackend.flip">

0 commit comments

Comments
 (0)