2-Wasserstein distance calculation

Background

The 2-Wasserstein distance \(W\) is a metric that measures the distance between two probability distributions, for example the distributions of a quantity under two different conditions \(A\) and \(B\).

For continuous distributions, it is given by

\[W := W(F_A, F_B) = \bigg( \int_0^1 \big|F_A^{-1}(u) - F_B^{-1}(u) \big|^2 du \bigg)^\frac{1}{2},\]

where \(F_A\) and \(F_B\) are the corresponding cumulative distribution functions (CDFs) and \(F_A^{-1}\) and \(F_B^{-1}\) the respective quantile functions.
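As a sanity check of the definition, the integral can be evaluated numerically in base R when both distributions are fully specified. The sketch below assumes \(F_A = N(0, 1)\) and \(F_B = N(2, 3^2)\), chosen only for illustration. For two normal distributions the quantile difference is linear in \(\Phi^{-1}(u)\), so the exact value is \(W = \sqrt{(0-2)^2 + (1-3)^2} = \sqrt{8}\). The helper name w2_continuous is hypothetical and not part of waddR.

```r
# Hypothetical helper (not part of waddR): 2-Wasserstein distance between two
# continuous distributions, given their quantile functions qA and qB.
w2_continuous <- function(qA, qB) {
  # Integrand |F_A^{-1}(u) - F_B^{-1}(u)|^2 on (0, 1). The Gauss-Kronrod rule
  # behind integrate() only evaluates interior points, so the infinite normal
  # quantiles at u = 0 and u = 1 are never touched.
  integrand <- function(u) (qA(u) - qB(u))^2
  sqrt(integrate(integrand, lower = 0, upper = 1)$value)
}

# Illustrative choice: F_A = N(0, 1), F_B = N(2, 3^2)
qA <- function(u) qnorm(u, mean = 0, sd = 1)
qB <- function(u) qnorm(u, mean = 2, sd = 3)

W <- w2_continuous(qA, qB)
W      # close to sqrt(8), i.e. about 2.828
```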

We specifically consider the squared 2-Wasserstein distance \(d := W^2\), which can be decomposed into location, size, and shape terms: \[d := d(F_A, F_B) = \int_0^1 \big|F_A^{-1}(u) - F_B^{-1}(u) \big|^2 du = \underbrace{\big(\mu_A - \mu_B\big)^2}_{\text{location}} + \underbrace{\big(\sigma_A - \sigma_B\big)^2}_{\text{size}} + \underbrace{2\sigma_A \sigma_B \big(1 - \rho^{A,B}\big)}_{\text{shape}},\]

where \(\mu_A\) and \(\mu_B\) are the respective means, \(\sigma_A\) and \(\sigma_B\) are the respective standard deviations, and \(\rho^{A,B}\) is the Pearson correlation of the points in the quantile-quantile plot of \(F_A\) and \(F_B\).
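The decomposition can be checked numerically for two distributions that differ in shape. The following is an illustration under our own assumptions, not waddR code: it takes \(F_A = N(0, 1)\) and \(F_B = \text{Exp}(1)\), both with standard deviation 1, writes every quantity as an integral of the quantile functions over \((0, 1)\), and compares the direct integral with the sum of the three terms.

```r
# Quantile functions of two illustrative distributions (assumption):
# F_A = N(0, 1) and F_B = Exp(rate = 1), which differ in shape.
qA <- function(u) qnorm(u)
qB <- function(u) qexp(u, rate = 1)

# Integral of a function of u over (0, 1)
int01 <- function(f) integrate(f, lower = 0, upper = 1)$value

# Means and standard deviations, written as integrals of the quantile functions
muA <- int01(qA)
muB <- int01(qB)
sdA <- sqrt(int01(function(u) (qA(u) - muA)^2))
sdB <- sqrt(int01(function(u) (qB(u) - muB)^2))

# Pearson correlation along the continuous Q-Q curve u -> (qA(u), qB(u))
rho <- int01(function(u) (qA(u) - muA) * (qB(u) - muB)) / (sdA * sdB)

# Left-hand side: squared 2-Wasserstein distance from the definition
d_direct <- int01(function(u) (qA(u) - qB(u))^2)

# Right-hand side: location + size + shape
location <- (muA - muB)^2
size     <- (sdA - sdB)^2
shape    <- 2 * sdA * sdB * (1 - rho)
d_decomp <- location + size + shape

c(direct = d_direct, decomposition = d_decomp)
c(location = location, size = size, shape = shape)
```

Since both distributions have standard deviation 1, the size term is essentially zero; beyond the location term (means 0 and 1), the remaining difference is carried entirely by the shape term, i.e. by \(\rho^{A,B} < 1\).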

Usage in two-sample setting

If the distributions \(F_A\) and \(F_B\) are not explicitly given and information is only available in the form of samples from \(F_A\) and \(F_B\), respectively, we use the corresponding empirical CDFs \(\hat{F}_A\) and \(\hat{F}_B\). The squared 2-Wasserstein distance is then approximated by

\[d(\hat{F}_A, \hat{F}_B) \approx \frac{1}{K} \sum_{k=1}^K \big(Q_A^{\alpha_k} - Q_B^{\alpha_k} \big)^2 \approx \big(\hat{\mu}_A - \hat{\mu}_B\big)^2 + \big(\hat{\sigma}_A - \hat{\sigma}_B\big)^2 + 2\hat{\sigma}_A \hat{\sigma}_B \big(1 - \hat{\rho}^{A,B}\big).\]

Here, \(Q_A^{\alpha_k}\) and \(Q_B^{\alpha_k}\) denote the quantiles of \(\hat{F}_A\) and \(\hat{F}_B\), respectively, at the equidistant levels \(\alpha_k := \frac{k-0.5}{K}, k = 1, \dots, K\), where we use \(K := 1000\) in our implementation. Moreover, \(\hat{\mu}_A, \hat{\mu}_B, \hat{\sigma}_A, \hat{\sigma}_B\) and \(\hat{\rho}^{A,B}\) denote the empirical versions of the corresponding quantities.
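The first approximation can be written in a few lines of base R. This is a sketch, not waddR's implementation: the function name sq_wass_quantiles is hypothetical, and taking the quantiles from the inverse empirical CDF (quantile(..., type = 1)) is an assumption. The default K = 1000 matches the value stated above.

```r
# Hypothetical base-R sketch (not waddR's code) of the first approximation:
# mean squared difference of K equidistant quantiles at levels (k - 0.5) / K.
# Assumption: quantiles come from the inverse empirical CDF (type = 1).
sq_wass_quantiles <- function(x, y, K = 1000) {
  alpha <- (seq_len(K) - 0.5) / K
  qx <- quantile(x, probs = alpha, type = 1, names = FALSE)
  qy <- quantile(y, probs = alpha, type = 1, names = FALSE)
  mean((qx - qy)^2)
}

set.seed(1)
a <- rnorm(100, mean = 0, sd = 1)
b <- rnorm(250, mean = 2, sd = 1)   # the two sample sizes may differ
sq_wass_quantiles(a, b)
```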

Three implementations

The waddR package offers three functions to compute the 2-Wasserstein distance in two-sample settings.

We will use samples from normal distributions to illustrate them.

library(waddR)

set.seed(24)
x <- rnorm(100,mean=0,sd=1)
y <- rnorm(100,mean=2,sd=1)

The first function, wasserstein_metric, is a faster C++ reimplementation of the wasserstein1d function from the R package transport. Like wasserstein1d, it computes general \(p\)-Wasserstein distances; for \(p=2\), we obtain the 2-Wasserstein distance \(W\).
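For two samples of equal size with equal weights, the optimal coupling in one dimension simply matches the order statistics. The base-R sketch below covers only that special case; the function name wasserstein_p_equal_n is hypothetical, and unequal sample sizes or weights are deliberately left out.

```r
# Hypothetical sketch (not the waddR or transport implementation):
# p-Wasserstein distance between two samples of EQUAL size n with equal
# weights, W_p = ( (1/n) * sum_i |x_(i) - y_(i)|^p )^(1/p).
wasserstein_p_equal_n <- function(x, y, p = 1) {
  stopifnot(length(x) == length(y), p >= 1)
  mean(abs(sort(x) - sort(y))^p)^(1 / p)
}

x_small <- c(3, 1, 2)
y_small <- c(0, 0, 0)
wasserstein_p_equal_n(x_small, y_small, p = 1)  # (1 + 2 + 3) / 3 = 2
wasserstein_p_equal_n(x_small, y_small, p = 2)  # sqrt((1 + 4 + 9) / 3)
```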

wasserstein_metric(x,y,p=2)
#> [1] 2.044457

The corresponding value of the squared 2-Wasserstein distance \(d\) is then computed as:

wasserstein_metric(x,y,p=2)^2
#> [1] 4.179803

The second function, squared_wass_approx, computes the squared 2-Wasserstein distance as the mean squared difference of the equidistant quantiles (first approximation in the previous formula). This function is currently used to compute the squared 2-Wasserstein distances in the testing procedures.
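In this example, squared_wass_approx(x, y) and wasserstein_metric(x, y, p=2)^2 agree to all printed digits. A plausible explanation, assuming the quantiles are taken from the inverse empirical CDF (R's quantile(..., type = 1)), is that with \(n = 100\) and \(K = 1000\) every order statistic is selected by exactly 10 of the levels \(\alpha_k\), so averaging over the quantile levels is the same as averaging over sorted sample pairs. The base-R check below counts this.

```r
n <- 100                          # sample size in the example
K <- 1000                         # number of quantile levels
alpha <- (seq_len(K) - 0.5) / K   # levels alpha_k = (k - 0.5) / K
# A type-1 quantile at level alpha is the order statistic ceiling(n * alpha)
# (here n * alpha is never an integer, so there are no boundary cases)
idx <- ceiling(n * alpha)
table(tabulate(idx, nbins = n))   # every order statistic is hit 10 times
```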

squared_wass_approx(x,y)
#> [1] 4.179803

The third function, squared_wass_decomp, approximates the squared 2-Wasserstein distance by the sum of the location, size, and shape terms from the above decomposition (second approximation in the previous formula). It also returns the individual terms.
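The second approximation can be sketched in base R as follows. This is an illustration, not waddR's code: the function name wass_decomp_sketch is hypothetical, and it is an assumption that means and standard deviations are estimated from the raw samples with mean() and sd(), while \(\hat{\rho}^{A,B}\) is estimated from the paired type-1 quantiles at the levels \(\alpha_k\).

```r
# Hypothetical base-R sketch of the second approximation (not waddR's code).
wass_decomp_sketch <- function(x, y, K = 1000) {
  alpha <- (seq_len(K) - 0.5) / K
  # Paired quantiles (inverse empirical CDF, an assumption) for rho
  qx <- quantile(x, probs = alpha, type = 1, names = FALSE)
  qy <- quantile(y, probs = alpha, type = 1, names = FALSE)
  # Sample standard deviations (n - 1 denominator)
  sx <- sd(x)
  sy <- sd(y)
  location <- (mean(x) - mean(y))^2
  size     <- (sx - sy)^2
  shape    <- 2 * sx * sy * (1 - cor(qx, qy))
  list(distance = location + size + shape,
       location = location, size = size, shape = shape)
}

set.seed(24)
x <- rnorm(100, mean = 0, sd = 1)
y <- rnorm(100, mean = 2, sd = 1)
str(wass_decomp_sketch(x, y))
```

In this sketch, the \(n-1\) denominator of sd() is what makes the sum differ slightly from the mean squared quantile difference; with \(1/n\) standard deviations and sample sizes that divide \(K\), the two would agree exactly.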

squared_wass_decomp(x,y)
#> $distance
#> [1] 4.180458
#> 
#> $location
#> [1] 4.114983
#> 
#> $size
#> [1] 0.002307
#> 
#> $shape
#> [1] 0.06316766

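In this example, the decomposition attributes almost all of the difference to location (mean): the location term (about 4.11) is close to the squared difference of the true means, \((0-2)^2 = 4\), while the size and shape terms are small. This is consistent with how the samples were generated, namely from two normal distributions with the same standard deviation that differ only in their means.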
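For contrast, the decomposition attributes the difference to the shape term when two samples share approximately the same mean and standard deviation but come from differently shaped distributions. The base-R sketch below assumes a normal and an exponential sample, both with mean 1 and standard deviation 1, and uses equal sample sizes with \(1/n\) standard deviations so that the three terms sum exactly to the mean squared difference of the sorted samples. It is an illustration, not a call to waddR.

```r
set.seed(1)
n <- 5000
x <- rnorm(n, mean = 1, sd = 1)   # symmetric
y <- rexp(n, rate = 1)            # right-skewed, also mean 1 and sd 1

# Standard deviation with the 1/n denominator
sd_pop <- function(v) sqrt(mean((v - mean(v))^2))

# Sorted samples play the role of the paired quantiles
xs <- sort(x)
ys <- sort(y)
location <- (mean(xs) - mean(ys))^2
size     <- (sd_pop(xs) - sd_pop(ys))^2
shape    <- 2 * sd_pop(xs) * sd_pop(ys) * (1 - cor(xs, ys))

c(total = mean((xs - ys)^2), location = location, size = size, shape = shape)
```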


Session Info

sessionInfo()
#> R version 4.4.0 beta (2024-04-15 r86425)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 22.04.4 LTS
#> 
#> Matrix products: default
#> BLAS:   /home/biocbuild/bbs-3.19-bioc/R/lib/libRblas.so 
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: America/New_York
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] waddR_1.18.0
#> 
#> loaded via a namespace (and not attached):
#>  [1] SummarizedExperiment_1.34.0 xfun_0.43                  
#>  [3] bslib_0.7.0                 Biobase_2.64.0             
#>  [5] lattice_0.22-6              vctrs_0.6.5                
#>  [7] tools_4.4.0                 generics_0.1.3             
#>  [9] stats4_4.4.0                curl_5.2.1                 
#> [11] parallel_4.4.0              tibble_3.2.1               
#> [13] fansi_1.0.6                 RSQLite_2.3.6              
#> [15] blob_1.2.4                  pkgconfig_2.0.3            
#> [17] Matrix_1.7-0                arm_1.14-4                 
#> [19] dbplyr_2.5.0                S4Vectors_0.42.0           
#> [21] lifecycle_1.0.4             GenomeInfoDbData_1.2.12    
#> [23] compiler_4.4.0              codetools_0.2-20           
#> [25] eva_0.2.6                   GenomeInfoDb_1.40.0        
#> [27] htmltools_0.5.8.1           sass_0.4.9                 
#> [29] yaml_2.3.8                  nloptr_2.0.3               
#> [31] pillar_1.9.0                crayon_1.5.2               
#> [33] jquerylib_0.1.4             MASS_7.3-60.2              
#> [35] BiocParallel_1.38.0         SingleCellExperiment_1.26.0
#> [37] DelayedArray_0.30.0         cachem_1.0.8               
#> [39] boot_1.3-30                 abind_1.4-5                
#> [41] nlme_3.1-164                tidyselect_1.2.1           
#> [43] digest_0.6.35               purrr_1.0.2                
#> [45] dplyr_1.1.4                 splines_4.4.0              
#> [47] fastmap_1.1.1               grid_4.4.0                 
#> [49] SparseArray_1.4.0           cli_3.6.2                  
#> [51] magrittr_2.0.3              S4Arrays_1.4.0             
#> [53] utf8_1.2.4                  withr_3.0.0                
#> [55] filelock_1.0.3              UCSC.utils_1.0.0           
#> [57] bit64_4.0.5                 rmarkdown_2.26             
#> [59] XVector_0.44.0              httr_1.4.7                 
#> [61] matrixStats_1.3.0           lme4_1.1-35.3              
#> [63] bit_4.0.5                   coda_0.19-4.1              
#> [65] memoise_2.0.1               evaluate_0.23              
#> [67] knitr_1.46                  GenomicRanges_1.56.0       
#> [69] IRanges_2.38.0              BiocFileCache_2.12.0       
#> [71] rlang_1.1.3                 Rcpp_1.0.12                
#> [73] glue_1.7.0                  DBI_1.2.2                  
#> [75] BiocGenerics_0.50.0         minqa_1.2.6                
#> [77] jsonlite_1.8.8              R6_2.5.1                   
#> [79] MatrixGenerics_1.16.0       zlibbioc_1.50.0