An alignment based on transport weights sets the weight between topics \(k\) and \(k^\prime\) by solving an optimal transport problem with (1) costs given by the Jensen-Shannon divergence between \(\beta_{k}\) and \(\beta_{k^\prime}\) and (2) masses given by the total topic mixed memberships \(\sum_{i}\gamma_{ik}\) and \(\sum_{i}\gamma_{ik^\prime}\). If two topics have similar total mixed membership and similar topic distributions \(\beta\), they receive a high transport alignment weight.
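For orientation, one common way to write an entropy-regularized optimal transport problem of this kind is shown below. This is a sketch under assumptions, not a statement of the exact objective the function optimizes: the symbols \(P\), \(C\), \(a\), \(b\), and \(\varepsilon\) are introduced here for illustration, with \(\varepsilon\) assumed to play the role of the reg argument.

```latex
\[
\min_{P \ge 0} \; \sum_{k, k^\prime} P_{k k^\prime} C_{k k^\prime}
  + \varepsilon \sum_{k, k^\prime} P_{k k^\prime} \left( \log P_{k k^\prime} - 1 \right)
\quad \text{subject to} \quad
\sum_{k^\prime} P_{k k^\prime} = a_{k}, \qquad
\sum_{k} P_{k k^\prime} = b_{k^\prime},
\]
\[
C_{k k^\prime} = \mathrm{JSD}\left(\beta_{k}, \beta_{k^\prime}\right), \qquad
a_{k} = \sum_{i} \gamma_{ik}, \qquad
b_{k^\prime} = \sum_{i} \gamma_{ik^\prime}.
\]
```

Under this reading, the entry \(P_{k k^\prime}\) is the weight between topics \(k\) and \(k^\prime\). Assuming each row of a mixed membership matrix sums to one, both sets of masses total the number of samples, which is consistent with the requirement that the two matrices have the same number of samples.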

transport_weights(gammas, betas, reg = 0.1, ...)

Arguments

gammas

(required) A list of length two, containing the mixed membership matrices (each a matrix of dimension n-samples by k-topics) to compare. The number of columns may differ between the two matrices, but the number of samples must be equal.

betas

(required) A list of length two, containing the topic matrices (each a matrix of dimension k-topics by d-dimensions). The number of rows may differ between the two matrices, but the number of columns must be the same.

reg

(optional) How much regularization to use in the Sinkhorn optimal transport algorithm? Defaults to 0.1.

...

(optional) Other keyword arguments. Not used here, but included for consistency with other weight functions.

Value

products A data.frame giving the transport alignment weight of each pair of topics across the two input matrices, with columns k, k_next, and weight.

Examples

library(purrr)
data <- rmultinom(10, 20, rep(0.1, 20))
lda_params <- setNames(map(1:5, ~ list(k = .)), 1:5)
lda_models <- run_lda_models(data, lda_params)
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
gammas <- list(lda_models[[3]]$gamma, lda_models[[5]]$gamma)
betas <- list(lda_models[[3]]$beta, lda_models[[5]]$beta)
transport_weights(gammas, betas)
#> # A tibble: 15 × 3
#>    k     k_next  weight
#>    <chr> <chr>    <dbl>
#>  1 1     1      3.19   
#>  2 1     2      0.227  
#>  3 1     3      0.392  
#>  4 1     4      2.86   
#>  5 1     5      0.00343
#>  6 2     1      0.416  
#>  7 2     2      1.12   
#>  8 2     3      1.42   
#>  9 2     4      0.304  
#> 10 2     5      3.40   
#> 11 3     1      0.398  
#> 12 3     2      2.65   
#> 13 3     3      2.18   
#> 14 3     4      0.837  
#> 15 3     5      0.597
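
To make the roles of the costs, masses, and reg more concrete, here is a minimal base-R sketch of Sinkhorn scaling with Jensen-Shannon costs. It is not the package's implementation: the helper names jsd and sinkhorn_sketch are hypothetical, the fixed iteration count is an arbitrary choice, and the toy topic matrices and masses are made-up values standing in for beta and for the column sums of gamma.

```r
# Illustrative sketch only; not the package's implementation.
# `jsd` and `sinkhorn_sketch` are hypothetical helper names.

# Jensen-Shannon divergence between two probability vectors (natural log).
jsd <- function(p, q) {
  m <- (p + q) / 2
  kl <- function(x, y) sum(ifelse(x > 0, x * log(x / y), 0))
  (kl(p, m) + kl(q, m)) / 2
}

# Entropy-regularized optimal transport via Sinkhorn scaling.
# cost: matrix of pairwise costs; a, b: masses with equal totals.
sinkhorn_sketch <- function(cost, a, b, reg = 0.1, n_iter = 500) {
  K <- exp(-cost / reg)              # Gibbs kernel
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (i in seq_len(n_iter)) {
    u <- a / as.vector(K %*% v)          # rescale to match row masses
    v <- b / as.vector(crossprod(K, u))  # rescale to match column masses
  }
  outer(u, v) * K                    # transport plan diag(u) K diag(v)
}

# Toy topic matrices (rows are topics, columns are words); made-up values.
beta_1 <- rbind(c(0.7, 0.2, 0.1), c(0.1, 0.2, 0.7))
beta_2 <- rbind(c(0.6, 0.3, 0.1), c(0.2, 0.6, 0.2), c(0.1, 0.3, 0.6))

# Toy masses standing in for colSums(gamma); both total 10 samples.
mass_1 <- c(6, 4)
mass_2 <- c(5, 2, 3)

# Pairwise Jensen-Shannon costs between topics of the two models.
cost <- outer(
  seq_len(nrow(beta_1)), seq_len(nrow(beta_2)),
  Vectorize(function(i, j) jsd(beta_1[i, ], beta_2[j, ]))
)

P <- sinkhorn_sketch(cost, mass_1, mass_2, reg = 0.1)
round(P, 3)
```

In this sketch the rows of P sum to the first model's masses and the columns to the second model's, so the weights total the number of samples. Smaller values of reg push the plan toward the unregularized optimal transport solution, while larger values spread mass more evenly across topic pairs.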