
2-Wasserstein Distance calculation

Background

The 2-Wasserstein distance $W$ is a metric that quantifies the distance between two distributions, here representing two different conditions $A$ and $B$.

For continuous distributions, it is given by

$$W := W(F_A, F_B) = \left( \int_0^1 \left| F_A^{-1}(u) - F_B^{-1}(u) \right|^2 \, du \right)^{\frac{1}{2}},$$

where $F_A$ and $F_B$ are the corresponding cumulative distribution functions (CDFs) and $F_A^{-1}$ and $F_B^{-1}$ the respective quantile functions.
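As a quick numerical illustration of this definition, the integral can be evaluated directly when both quantile functions are known. The sketch below uses base R's integrate() and two normal distributions chosen purely for illustration, N(0, 1) and N(2, 3^2); the helper name wasserstein2_continuous is hypothetical and not part of waddR.

```r
# Hypothetical helper (not part of waddR): 2-Wasserstein distance between two
# continuous distributions, given their quantile functions q_a and q_b.
wasserstein2_continuous <- function(q_a, q_b) {
  integrand <- function(u) (q_a(u) - q_b(u))^2
  sqrt(integrate(integrand, lower = 0, upper = 1)$value)
}

# Illustrative choice: A ~ N(0, 1), B ~ N(2, 3^2)
q_a <- function(u) qnorm(u, mean = 0, sd = 1)
q_b <- function(u) qnorm(u, mean = 2, sd = 3)

w <- wasserstein2_continuous(q_a, q_b)
w^2  # close to (0 - 2)^2 + (1 - 3)^2 = 8
```

For two normal distributions the quantile functions are linearly related, so the squared distance reduces to the squared difference in means plus the squared difference in standard deviations.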

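We specifically consider the squared 2-Wasserstein distance $d := W^2$, which offers the following decomposition into location, size, and shape terms:

$$d := d(F_A, F_B) = \int_0^1 \left| F_A^{-1}(u) - F_B^{-1}(u) \right|^2 \, du = \underbrace{(\mu_A - \mu_B)^2}_{\text{location}} + \underbrace{(\sigma_A - \sigma_B)^2}_{\text{size}} + \underbrace{2 \sigma_A \sigma_B \left(1 - \rho^{A,B}\right)}_{\text{shape}},$$

where $\mu_A$ and $\mu_B$ are the respective means, $\sigma_A$ and $\sigma_B$ are the respective standard deviations, and $\rho^{A,B}$ is the Pearson correlation of the points in the quantile-quantile plot of $F_A$ and $F_B$.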
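The decomposition follows from a short calculation, sketched here under the assumption that both distributions have finite, nonzero variances. Let $U$ be uniform on $(0, 1)$ and set $X := F_A^{-1}(U)$ and $Y := F_B^{-1}(U)$, so that $X$ has distribution $F_A$, $Y$ has distribution $F_B$, and the correlation of $X$ and $Y$ is the quantile-quantile correlation. The symbols $U$, $X$, and $Y$ are introduced here for illustration only.

```latex
\begin{aligned}
d &= \mathbb{E}\left[(X - Y)^2\right] \\
  &= \left(\mathbb{E}[X] - \mathbb{E}[Y]\right)^2 + \operatorname{Var}(X - Y) \\
  &= (\mu_A - \mu_B)^2 + \sigma_A^2 + \sigma_B^2 - 2 \rho^{A,B} \sigma_A \sigma_B \\
  &= (\mu_A - \mu_B)^2 + (\sigma_A - \sigma_B)^2 + 2 \sigma_A \sigma_B \left(1 - \rho^{A,B}\right).
\end{aligned}
```

The shape term vanishes exactly when the quantile-quantile plot is a straight line with positive slope, that is, when the two distributions differ only by a shift and a rescaling.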

Usage in two-sample setting

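In case the distributions $F_A$ and $F_B$ are not explicitly given and information is only available in the form of samples from $F_A$ and $F_B$, respectively, we use the corresponding empirical CDFs $\hat{F}_A$ and $\hat{F}_B$. Then, the squared 2-Wasserstein distance is computed by

$$d(\hat{F}_A, \hat{F}_B) \approx \frac{1}{K} \sum_{k=1}^{K} \left( Q_A^{\alpha_k} - Q_B^{\alpha_k} \right)^2 \approx (\hat{\mu}_A - \hat{\mu}_B)^2 + (\hat{\sigma}_A - \hat{\sigma}_B)^2 + 2 \hat{\sigma}_A \hat{\sigma}_B \left(1 - \hat{\rho}^{A,B}\right).$$

Here, $Q_A^{\alpha_k}$ and $Q_B^{\alpha_k}$ denote equidistant empirical quantiles of $F_A$ and $F_B$, respectively, at the levels $\alpha_k := \frac{k - 0.5}{K}$, $k = 1, \dots, K$, using $K = 1000$ in our implementation. Moreover, $\hat{\mu}_A$, $\hat{\mu}_B$, $\hat{\sigma}_A$, $\hat{\sigma}_B$, and $\hat{\rho}^{A,B}$ denote the empirical versions of the corresponding quantities from the original decomposition of $d$.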
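The two approximations can be sketched in base R as follows. This is not the waddR implementation: it assumes the default interpolation of quantile() and the usual sample mean, sample standard deviation, and Pearson correlation, any of which may differ from what the package uses internally, so the values need not match the package output digit for digit. The names d_quantile and d_decomp are chosen here for illustration.

```r
set.seed(24)
x <- rnorm(100, mean = 0, sd = 1)
y <- rnorm(100, mean = 2, sd = 1)

# Equidistant quantile levels alpha_k = (k - 0.5) / K
K <- 1000
alpha <- (seq_len(K) - 0.5) / K
qx <- quantile(x, probs = alpha, names = FALSE)
qy <- quantile(y, probs = alpha, names = FALSE)

# First approximation: mean squared difference of the quantiles
d_quantile <- mean((qx - qy)^2)

# Second approximation: location + size + shape
location <- (mean(x) - mean(y))^2
size     <- (sd(x) - sd(y))^2
shape    <- 2 * sd(x) * sd(y) * (1 - cor(qx, qy))
d_decomp <- location + size + shape

c(d_quantile = d_quantile, d_decomp = d_decomp)
```

The two results should be close but not identical, since the second one replaces moments of the quantile vector with moments of the raw samples.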

Three Implementations

The package waddR offers three functions to compute the 2-Wasserstein distance in two-sample settings.

We will use samples from normal distributions to illustrate them.

library(waddR)

set.seed(24)
x <- rnorm(100, mean=0, sd=1)
y <- rnorm(100, mean=2, sd=1)

The first function, wasserstein_metric, offers a faster reimplementation in C++ of the function wasserstein1d from the R package transport, which computes the original 2-Wasserstein distance $W$.

wasserstein_metric(x, y)
#> [1] 2.028542

The corresponding value of the squared 2-Wasserstein distance $d$ is then:

wasserstein_metric(x, y)^2
#> [1] 4.114983
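As a package-independent cross-check, the empirical Wasserstein distance between two samples of equal size can be computed in base R by pairing the sorted observations. The sketch below is an illustration only; the helper wasserstein_sorted is hypothetical and not part of waddR. The exponent matters here: wasserstein1d in transport takes an exponent argument p with default p = 1, and if wasserstein_metric mirrors that default (an assumption to check against ?wasserstein_metric), the 2-Wasserstein distance requires passing p = 2 explicitly.

```r
set.seed(24)
x <- rnorm(100, mean = 0, sd = 1)
y <- rnorm(100, mean = 2, sd = 1)

# Hypothetical helper (not part of waddR): exact empirical p-Wasserstein
# distance for two samples of EQUAL size, pairing the order statistics.
wasserstein_sorted <- function(x, y, p = 2) {
  stopifnot(length(x) == length(y))
  mean(abs(sort(x) - sort(y))^p)^(1 / p)
}

w1 <- wasserstein_sorted(x, y, p = 1)
w2 <- wasserstein_sorted(x, y, p = 2)
c(w1 = w1, w2 = w2, w2_squared = w2^2)
```

When the sorted samples do not cross, the 1-Wasserstein distance equals the absolute difference in sample means, so its square coincides with the location term of the decomposition. Comparing w1 and w2 with the printed values is a quick way to see which exponent a given call used.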

The second function, squared_wass_approx, computes the squared 2-Wasserstein distance as the mean squared difference of the equidistant quantiles (the first approximation in the formula above). This function is currently used to compute the 2-Wasserstein distance in the testing procedures.

squared_wass_approx(x, y)
#> [1] 4.179803

The third function, squared_wass_decomp, approximates the squared 2-Wasserstein distance by adding the location, size, and shape terms from the above decomposition (the second approximation in the formula above). It also returns the respective decomposition values.

squared_wass_decomp(x, y)
#> $distance
#> [1] 4.180458
#> 
#> $location
#> [1] 4.114983
#> 
#> $size
#> [1] 0.002307
#> 
#> $shape
#> [1] 0.06316766

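The decomposition shows that, in this example, the two distributions differ almost entirely in location (mean): the location term accounts for about 98% of the distance, while the size and shape terms are close to zero. This is consistent with the underlying model of two normal distributions that differ only in their means.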
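To make this interpretation concrete, the relative contribution of each term can be computed from the printed values. The snippet below is plain arithmetic on the squared_wass_decomp output shown above; the names terms and shares are chosen here for illustration.

```r
# Values copied from the squared_wass_decomp(x, y) output above
terms <- c(location = 4.114983, size = 0.002307, shape = 0.06316766)

# The three terms add up to the reported $distance
sum(terms)  # 4.180458

# Share of the total distance attributed to each term
shares <- round(terms / sum(terms), 3)
shares      # location 0.984, size 0.001, shape 0.015
```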

Session Info

sessionInfo()
#> R version 4.0.0 (2020-04-24)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Ubuntu 18.04.4 LTS
#> 
#> Matrix products: default
#> BLAS:   /home/biocbuild/bbs-3.11-bioc/R/lib/libRblas.so
#> LAPACK: /home/biocbuild/bbs-3.11-bioc/R/lib/libRlapack.so
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] waddR_1.2.0
#> 
#> loaded via a namespace (and not attached):
#>  [1] SummarizedExperiment_1.18.0 statmod_1.4.34             
#>  [3] tidyselect_1.0.0            xfun_0.13                  
#>  [5] purrr_0.3.4                 splines_4.0.0              
#>  [7] lattice_0.20-41             vctrs_0.2.4                
#>  [9] htmltools_0.4.0             stats4_4.0.0               
#> [11] BiocFileCache_1.12.0        yaml_2.2.1                 
#> [13] blob_1.2.1                  rlang_0.4.5                
#> [15] nloptr_1.2.2.1              pillar_1.4.3               
#> [17] glue_1.4.0                  DBI_1.1.0                  
#> [19] BiocParallel_1.22.0         rappdirs_0.3.1             
#> [21] SingleCellExperiment_1.10.0 BiocGenerics_0.34.0        
#> [23] bit64_0.9-7                 dbplyr_1.4.3               
#> [25] matrixStats_0.56.0          GenomeInfoDbData_1.2.3     
#> [27] lifecycle_0.2.0             stringr_1.4.0              
#> [29] zlibbioc_1.34.0             coda_0.19-3                
#> [31] memoise_1.1.0               evaluate_0.14              
#> [33] Biobase_2.48.0              knitr_1.28                 
#> [35] IRanges_2.22.0              GenomeInfoDb_1.24.0        
#> [37] parallel_4.0.0              curl_4.3                   
#> [39] Rcpp_1.0.4.6                arm_1.11-1                 
#> [41] DelayedArray_0.14.0         S4Vectors_0.26.0           
#> [43] XVector_0.28.0              abind_1.4-5                
#> [45] lme4_1.1-23                 bit_1.1-15.2               
#> [47] digest_0.6.25               stringi_1.4.6              
#> [49] dplyr_0.8.5                 GenomicRanges_1.40.0       
#> [51] grid_4.0.0                  tools_4.0.0                
#> [53] bitops_1.0-6                magrittr_1.5               
#> [55] RCurl_1.98-1.2              tibble_3.0.1               
#> [57] RSQLite_2.2.0               crayon_1.3.4               
#> [59] pkgconfig_2.0.3             ellipsis_0.3.0             
#> [61] MASS_7.3-51.6               Matrix_1.2-18              
#> [63] minqa_1.2.4                 assertthat_0.2.1           
#> [65] rmarkdown_2.1               httr_1.4.1                 
#> [67] boot_1.3-25                 R6_2.4.1                   
#> [69] nlme_3.1-147                compiler_4.0.0