{ "cells": [ { "cell_type": "markdown", "id": "2f9f2f21-569b-4429-8f5c-081e6089bc78", "metadata": {}, "source": [ "# Building up tools to compute an approximation of the 2-Wasserstein distance" ] }, { "cell_type": "markdown", "id": "1b42cdee-d0de-4562-8669-06916312af00", "metadata": {}, "source": [ "> In this section we create a function to compute an approximation of the 2-Wasserstein distance between two univariate datasets" ] }, { "cell_type": "code", "execution_count": 1, "id": "fb9a10f4-12da-4456-a50c-ed52bf6b2e04", "metadata": {}, "outputs": [], "source": [ "from dolfin import *\n", "import numpy as np\n", "import ot\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import pandas as pd\n", "from statFEM_analysis.oneDim import mean_assembler, kernMat, cov_assembler, sample_gp\n", "from scipy.stats import linregress\n", "from scipy import integrate\n", "from scipy.linalg import sqrtm\n", "from tqdm.notebook import tqdm\n", "import sympy; sympy.init_printing()\n", "# code for displaying matrices nicely\n", "def display_matrix(m):\n", "    display(sympy.Matrix(m))" ] }, { "cell_type": "markdown", "id": "d30ae0ca-9076-4be4-be54-63eaf72656c0", "metadata": {}, "source": [ "## Computing the 2-Wasserstein distance between two datasets\n", "\n", "We start with the function [wass()](statFEM_analysis.rst#statFEM_analysis.maxDist.wass), which estimates the 2-Wasserstein distance between two datasets `a` and `b` using the Python package [POT](https://github.com/PythonOT/POT)." ] }, { "cell_type": "code", "execution_count": 2, "id": "17b61beb-dd7d-4ec6-ba3a-d2326075eaf1", "metadata": {}, "outputs": [], "source": [ "from statFEM_analysis.maxDist import wass" ] }, { "cell_type": "markdown", "id": "781816a8-b1fd-4a1c-a56f-f5b9f1f25953", "metadata": {}, "source": [ "`wass` takes the two datasets `a` and `b`, together with an argument `n_bin` that sets the number of bins used to build the histogram of each dataset."
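, "\n", "\n", "As a hedged sketch of what a histogram-based estimate of this kind typically computes (an assumption about the general approach, not a statement of the exact implementation of `wass`): writing $x_i$ and $y_j$ for the bin centres of the two histograms and $p_i$, $q_j$ for the corresponding normalised bin weights (all four symbols are introduced here for illustration only), the discrete 2-Wasserstein distance is\n", "\n", "$$W_2(p,q) = \\left( \\min_{\\gamma \\in \\Pi(p,q)} \\sum_{i,j} \\gamma_{ij}\\,|x_i - y_j|^{2} \\right)^{1/2},$$\n", "\n", "where $\\Pi(p,q)$ denotes the set of non-negative matrices $\\gamma$ whose row sums equal $p$ and whose column sums equal $q$. Under this reading, a larger number of bins gives a finer discretisation of each dataset, at the cost of a larger transport problem."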
] }, { "cell_type": "markdown", "id": "1e6cb9d6-dd22-48ef-acd5-90edf9645c7a", "metadata": {}, "source": [ "Let's test this function out. First we make sure it gives $\\operatorname{wass}(a,a) = 0$ for any dataset $a$." ] }, { "cell_type": "code", "execution_count": 3, "id": "f5409c91-138f-42ac-a0ea-c5fbeeee6aa9", "metadata": {}, "outputs": [], "source": [ "# standard normal\n", "N = 1000 # number of samples\n", "n_bins = 10 # number of bins\n", "np.random.seed(134)\n", "a = np.random.normal(size=N)\n", "assert wass(a,a,n_bins) == 0" ] }, { "cell_type": "markdown", "id": "cfd46535-f54c-4332-9187-e9323d32de47", "metadata": {}, "source": [ "We also test it on samples from 2 different Gaussians, $a\\sim\\mathcal{N}(m_0,s_0^{2})$ and $b\\sim\\mathcal{N}(m_1,s_1^{2})$. We expect, theoretically, that $\\operatorname{wass}(a,b)=\\sqrt{|m_0-m_1|^{2}+|s_0-s_1|^{2}}$." ] }, { "cell_type": "code", "execution_count": 4, "id": "1a09e1d6-5fe0-49ee-b4e7-2f11d4fdcb79", "metadata": {}, "outputs": [], "source": [ "# set up means and standard deviations\n", "m_0 = 7\n", "m_1 = 58\n", "s_0 = 1.63\n", "s_1 = 0.7\n", "\n", "# draw the samples\n", "N = 1000\n", "#####################################\n", "n_bins = 50 # number of bins\n", "#####################################\n", "np.random.seed(2321)\n", "a = np.random.normal(loc = m_0, scale = s_0,size=N)\n", "b = np.random.normal(loc = m_1, scale = s_1,size=N)\n", "\n", "# tolerance for the comparison\n", "tol = 1e-1\n", "\n", "# compute the 2-wasserstein with our function and also the true theoretical value\n", "W = wass(a,b,n_bins)\n", "W_true = np.sqrt(np.abs(m_0-m_1)**2 + np.abs(s_0-s_1)**2)\n", "# compare\n", "assert np.abs(W - W_true) <= tol" ] }, { "cell_type": "markdown", "id": "c8a20642-7942-4e9e-9e58-63e678a1073a", "metadata": {}, "source": [ "Let's take the previous example and compute the distance for a range of different means and standard deviations." 
] }, { "cell_type": "code", "execution_count": 5, "id": "e09dbc00-94f1-4bad-bc11-24c4111a8022", "metadata": {}, "outputs": [], "source": [ "# set up range for means and standard deviations\n", "n = 40\n", "m_range = np.linspace(m_0 - 2, m_0 + 2, n)\n", "s_range = np.linspace(s_0/4, 2*s_0, n)\n", "\n", "# set up arrays to hold results with our function, the theoretical results, \n", "# and theoretical results using estimated means and standard deviations\n", "W = np.zeros((n, n))\n", "W_0 = np.zeros((n, n))\n", "W_est = np.zeros((n,n))\n", "\n", "N = 10000 # number of samples\n", "################################################\n", "n_bins = 100 # number of bins\n", "################################################\n", "np.random.seed(2321)\n", "a = np.random.normal(loc = m_0, scale = s_0,size=N)\n", "m_a_est = np.mean(a)\n", "s_a_est = np.std(a)\n", "\n", "# sample for each m,s in the ranges and compute the results\n", "for i, m in enumerate(m_range):\n", "    for j, s in enumerate(s_range):\n", "        b = np.random.normal(loc = m, scale = s, size = N)\n", "        m_est = np.mean(b)\n", "        s_est = np.std(b)\n", "\n", "        W[i,j] = wass(a,b,n_bins)\n", "        W_0[i,j] = np.sqrt(np.abs(m - m_0)**2 + np.abs(s - s_0)**2)\n", "        W_est[i,j] = np.sqrt(np.abs(m_est - m_a_est)**2 + np.abs(s_est - s_a_est)**2)" ] }, { "cell_type": "markdown", "id": "8e20cccc-a9b7-440c-adde-aec590edb288", "metadata": {}, "source": [ "Let's visualize the results:" ] }, { "cell_type": "code", "execution_count": 7, "id": "8d86ba25-6a18-4bfb-bb82-a19a07b7e818", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "Figure: filled contour plots of the distance over (m, s) in three panels titled 'POT', 'Estimated truth' and 'True', with a colorbar." ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
"source": [ "M, S = np.meshgrid(m_range, s_range,indexing='ij')\n", "plt.rcParams['figure.figsize'] = (16,5)\n", "fig, axs = plt.subplots(ncols=4, gridspec_kw=dict(width_ratios=[4,4,4,0.2]))\n", "axs[0].contourf(M, S, W)\n", "axs[0].scatter([m_0],[s_0],marker='X',c='red')\n", "axs[0].set_xlabel('$m$')\n", "axs[0].set_ylabel('$s$')\n", "axs[0].set_title('POT')\n", "\n", "axs[1].contourf(M, S, W_est)\n", "axs[1].scatter([m_0],[s_0],marker='X',c='red')\n", "axs[1].set_xlabel('$m$')\n", "axs[1].set_ylabel('$s$')\n", "axs[1].set_title('Estimated truth')\n", "\n", "axs[2].contourf(M, S, W_0)\n", "axs[2].scatter([m_0],[s_0],marker='X',c='red')\n", "axs[2].set_xlabel('$m$')\n", "axs[2].set_ylabel('$s$')\n", "axs[2].set_title('True')\n", "fig.colorbar(axs[np.argmax([W.max(), W_est.max(),W_0.max()])].collections[0], cax=axs[3])\n", "plt.tight_layout()\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 5 }