transport_weights.Rd
Transport weights set the alignment weight between topics \(k\) and \(k^\prime\) by solving an optimal transport problem in which (1) the cost of moving mass between two topics is the Jensen-Shannon divergence between \(\beta_{k}\) and \(\beta_{k^\prime}\), and (2) the mass at each topic is its total mixed membership, \(\sum_{i}\gamma_{ik}\) or \(\sum_{i}\gamma_{ik^\prime}\). Two topics receive a high transport weight when they have similar total mixed membership and similar \(\beta\).
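One common way to write an entropically regularized optimal transport problem of this kind is shown below. This is a sketch under assumptions, not a statement of the exact objective the function solves: the transport plan \(P\) and the regularization strength \(\varepsilon\) are notation introduced here, and reading \(\varepsilon\) as the `reg` argument is an assumption.

```latex
\[
\begin{aligned}
P^\star = \arg\min_{P \ge 0}\; & \sum_{k}\sum_{k^\prime} P_{kk^\prime}\,\mathrm{JSD}\left(\beta_{k}, \beta_{k^\prime}\right)
  + \varepsilon \sum_{k}\sum_{k^\prime} P_{kk^\prime}\left(\log P_{kk^\prime} - 1\right) \\
\text{subject to}\; & \sum_{k^\prime} P_{kk^\prime} = \sum_{i}\gamma_{ik},
  \qquad \sum_{k} P_{kk^\prime} = \sum_{i}\gamma_{ik^\prime}
\end{aligned}
\]
```

Under this reading, the weight reported for a pair of topics would be the corresponding entry \(P^\star_{kk^\prime}\). Each row of a mixed membership matrix sums to one, so both sets of masses total the number of samples, which is consistent with the requirement that the two models share the same samples.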
transport_weights(gammas, betas, reg = 0.1, ...)
gammas (required) A list of length two, containing the mixed membership matrices (each a matrix of dimension n-samples by k-topics) to compare. The number of columns may differ between the two matrices, but the number of samples must be equal.
betas (required) A list of length two, containing the topic matrices (each a matrix of dimension k-topics by d-dimensions). The number of rows may differ between the two matrices, but the number of columns must be the same.
reg (optional) The amount of regularization to use in the Sinkhorn optimal transport algorithm. Defaults to 0.1.
... (optional) Other keyword arguments. Not used here, but included for consistency with other weight functions.
A data.frame with columns k, k_next, and weight, giving the transport weight of each pair of topics across the two input models.
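To make the roles of the costs, the masses, and `reg` concrete, the computation can be sketched in base R. This is an illustrative sketch, not the package's implementation: `jsd()`, `sinkhorn_sketch()`, and `normalize_rows()` are hypothetical helpers defined here, the inputs are random toy matrices, and the rows of each topic matrix are assumed to be probability vectors (not log probabilities).

```r
# Jensen-Shannon divergence between two probability vectors (natural log).
jsd <- function(p, q) {
  m <- (p + q) / 2
  kl <- function(a, b) sum(ifelse(a > 0, a * log(a / b), 0))
  (kl(p, m) + kl(q, m)) / 2
}

# Plain Sinkhorn iterations for entropically regularized optimal transport.
# cost: K x K' cost matrix; p, q: masses with equal totals.
sinkhorn_sketch <- function(cost, p, q, reg = 0.1, n_iter = 1000) {
  kernel <- exp(-cost / reg)
  u <- rep(1, length(p))
  v <- rep(1, length(q))
  for (i in seq_len(n_iter)) {
    u <- p / as.vector(kernel %*% v)     # rescale to match row masses
    v <- q / as.vector(t(kernel) %*% u)  # rescale to match column masses
  }
  outer(u, v) * kernel                   # transport plan: diag(u) K diag(v)
}

# Toy inputs (assumed shapes): 10 samples, 6 dimensions, 2 vs. 3 topics.
set.seed(1)
normalize_rows <- function(x) x / rowSums(x)
beta1 <- normalize_rows(matrix(runif(2 * 6), nrow = 2))
beta2 <- normalize_rows(matrix(runif(3 * 6), nrow = 3))
gamma1 <- normalize_rows(matrix(runif(10 * 2), nrow = 10))
gamma2 <- normalize_rows(matrix(runif(10 * 3), nrow = 10))

# Costs: JSD between every pair of topics across the two models.
cost <- outer(
  seq_len(nrow(beta1)), seq_len(nrow(beta2)),
  Vectorize(function(k, l) jsd(beta1[k, ], beta2[l, ]))
)

# Masses: total mixed membership of each topic.
plan <- sinkhorn_sketch(cost, colSums(gamma1), colSums(gamma2), reg = 0.1)
plan
```

In this sketch the row and column sums of `plan` recover the topic masses, so its entries total the number of samples. Smaller values of `reg` concentrate the plan on low-cost pairs, while larger values spread mass more evenly.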
library(purrr)
library(alto) # assumption: the package that provides run_lda_models() and transport_weights()
data <- rmultinom(10, 20, rep(0.1, 20))
lda_params <- setNames(map(1:5, ~ list(k = .)), 1:5)
lda_models <- run_lda_models(data, lda_params)
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
#> Using default value 'VEM' for 'method' LDA parameter.
gammas <- list(lda_models[[3]]$gamma, lda_models[[5]]$gamma)
betas <- list(lda_models[[3]]$beta, lda_models[[5]]$beta)
transport_weights(gammas, betas)
#> # A tibble: 15 × 3
#>    k     k_next  weight
#>    <chr> <chr>    <dbl>
#>  1 1     1      3.19
#>  2 1     2      0.227
#>  3 1     3      0.392
#>  4 1     4      2.86
#>  5 1     5      0.00343
#>  6 2     1      0.416
#>  7 2     2      1.12
#>  8 2     3      1.42
#>  9 2     4      0.304
#> 10 2     5      3.40
#> 11 3     1      0.398
#> 12 3     2      2.65
#> 13 3     3      2.18
#> 14 3     4      0.837
#> 15 3     5      0.597
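A table like the one above is often easier to read after a small amount of post-processing. The sketch below is one possible approach in base R, not part of the function itself: it rebuilds the printed result by hand as a data.frame (values copied from the rounded output above) and then picks, for each topic `k`, the `k_next` topic that receives the most weight. The `share` column is a name introduced here for illustration.

```r
# Rebuild the printed result by hand (rounded values from the output above).
weights <- data.frame(
  k = rep(c("1", "2", "3"), each = 5),
  k_next = rep(as.character(1:5), times = 3),
  weight = c(
    3.19, 0.227, 0.392, 2.86, 0.00343,
    0.416, 1.12, 1.42, 0.304, 3.40,
    0.398, 2.65, 2.18, 0.837, 0.597
  )
)

# Share of each topic k's outgoing weight that goes to each k_next.
weights$share <- weights$weight / ave(weights$weight, weights$k, FUN = sum)

# For each topic k, keep the row with the largest weight.
by_k <- split(weights, weights$k)
strongest <- do.call(rbind, lapply(by_k, function(d) d[which.max(d$weight), ]))
strongest
```

With these values, topic 1 sends most of its weight to topic 1 in the second model, topic 2 to topic 5, and topic 3 to topic 2.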