-
-
Notifications
You must be signed in to change notification settings - Fork 369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improved rhat diagnostic #3266
Improved rhat diagnostic #3266
Changes from 17 commits
88173cc
98450d7
b09f512
dcd21f2
c20ce61
0e8379f
5630c51
50617bd
7c5880f
dc130f8
3433d9c
4b58653
3aeab3f
30333ce
77d875f
6221540
da10f8f
c3b1101
263ac22
b6ef4bb
51ad447
71bfbcf
c371c21
385c80b
51a2504
ba2b6f5
981eedc
de941ba
971aed7
61c4c6c
9dae7a0
51db135
f12f259
3cb8e97
da5bb8d
b3631be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -8,6 +8,7 @@ | |||||||||||||||||||||||||||||||||
#include <boost/accumulators/statistics/stats.hpp> | ||||||||||||||||||||||||||||||||||
#include <boost/accumulators/statistics/mean.hpp> | ||||||||||||||||||||||||||||||||||
#include <boost/accumulators/statistics/variance.hpp> | ||||||||||||||||||||||||||||||||||
#include <boost/math/distributions/normal.hpp> | ||||||||||||||||||||||||||||||||||
#include <algorithm> | ||||||||||||||||||||||||||||||||||
#include <cmath> | ||||||||||||||||||||||||||||||||||
#include <vector> | ||||||||||||||||||||||||||||||||||
|
@@ -16,6 +17,150 @@ | |||||||||||||||||||||||||||||||||
namespace stan { | ||||||||||||||||||||||||||||||||||
namespace analyze { | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes normalized average ranks for draws. Transforming them to normal | ||||||||||||||||||||||||||||||||||
* scores using inverse normal transformation and a fractional offset. Based on | ||||||||||||||||||||||||||||||||||
* paper https://arxiv.org/abs/1903.08008 | ||||||||||||||||||||||||||||||||||
* @param draws stores chains in columns | ||||||||||||||||||||||||||||||||||
* @return normal scores for average ranks of draws | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Eigen::MatrixXd rank_transform(const Eigen::MatrixXd& draws) { | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the error on jenkins. This makes it so that there are not multiple definitions for different translation units
Suggested change
|
||||||||||||||||||||||||||||||||||
const Eigen::Index rows = draws.rows(); | ||||||||||||||||||||||||||||||||||
const Eigen::Index cols = draws.cols(); | ||||||||||||||||||||||||||||||||||
const Eigen::Index size = rows * cols; | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
std::vector<std::pair<double, int>> value_with_index(size); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (Eigen::Index i = 0; i < size; ++i) { | ||||||||||||||||||||||||||||||||||
value_with_index[i] = {draws(i), i}; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
std::sort(value_with_index.begin(), value_with_index.end()); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Eigen::MatrixXd rankMatrix = Eigen::MatrixXd::Zero(rows, cols); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use CamelCase for template parameters and
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check your code as this should apply everywhere like |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
// Assigning average ranks | ||||||||||||||||||||||||||||||||||
for (Eigen::Index i = 0; i < size; ++i) { | ||||||||||||||||||||||||||||||||||
// Handle ties by averaging ranks | ||||||||||||||||||||||||||||||||||
Eigen::Index j = i + 1; | ||||||||||||||||||||||||||||||||||
double sumRanks = j; | ||||||||||||||||||||||||||||||||||
Eigen::Index count = 1; | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
while (j < size && value_with_index[j].first == value_with_index[i].first) { | ||||||||||||||||||||||||||||||||||
sumRanks += j + 1; // Rank starts from 1 | ||||||||||||||||||||||||||||||||||
++j; | ||||||||||||||||||||||||||||||||||
++count; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
double avgRank = sumRanks / count; | ||||||||||||||||||||||||||||||||||
boost::math::normal_distribution<double> dist; | ||||||||||||||||||||||||||||||||||
for (std::size_t k = i; k < j; ++k) { | ||||||||||||||||||||||||||||||||||
Eigen::Index index = value_with_index[k].second; | ||||||||||||||||||||||||||||||||||
double p = (avgRank - 3.0 / 8.0) / (size - 2.0 * 3.0 / 8.0 + 1.0); | ||||||||||||||||||||||||||||||||||
rankMatrix(index) = boost::math::quantile(dist, p); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
i = j - 1; // Skip over tied elements | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
return rankMatrix; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes square root of marginal posterior variance of the estimand by the | ||||||||||||||||||||||||||||||||||
* weigted average of within-chain variance W and between-chain variance B. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* @param draws stores chains in columns | ||||||||||||||||||||||||||||||||||
* @return square root of ((N-1)/N)W + B/N | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
inline double rhat(const Eigen::MatrixXd& draws) { | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think
Suggested change
|
||||||||||||||||||||||||||||||||||
const Eigen::Index num_chains = draws.cols(); | ||||||||||||||||||||||||||||||||||
const Eigen::Index num_draws = draws.rows(); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Eigen::VectorXd chain_mean(num_chains); | ||||||||||||||||||||||||||||||||||
chain_mean = draws.colwise().mean(); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
double total_mean = chain_mean.mean(); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
double var_between = num_draws | ||||||||||||||||||||||||||||||||||
* (chain_mean.array() - total_mean).square().sum() | ||||||||||||||||||||||||||||||||||
/ (num_chains - 1); | ||||||||||||||||||||||||||||||||||
double var_sum = 0; | ||||||||||||||||||||||||||||||||||
for (Eigen::Index col = 0; col < num_chains; ++col) { | ||||||||||||||||||||||||||||||||||
var_sum += (draws.col(col).array() - chain_mean(col)).square().sum() | ||||||||||||||||||||||||||||||||||
/ (num_draws - 1); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
double var_within = var_sum / num_chains; | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A little more advanced but we can also do this like the below. I added comments to explain the Eigen parts but delete them if you change this. Also in the below I assume you change
Suggested change
|
||||||||||||||||||||||||||||||||||
return sqrt((var_between / var_within + num_draws - 1) / num_draws); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the potential scale reduction (Rhat) using rank based diagnostic for | ||||||||||||||||||||||||||||||||||
* the specified parameter across all kept samples. Based on paper | ||||||||||||||||||||||||||||||||||
* https://arxiv.org/abs/1903.08008 | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* Current implementation assumes draws are stored in contiguous | ||||||||||||||||||||||||||||||||||
* blocks of memory. Chains are trimmed from the back to match the | ||||||||||||||||||||||||||||||||||
* length of the shortest chain. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* @param draws stores pointers to arrays of chains | ||||||||||||||||||||||||||||||||||
* @param sizes stores sizes of chains | ||||||||||||||||||||||||||||||||||
* @return potential scale reduction for the specified parameter | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
inline double compute_potential_scale_reduction_rank( | ||||||||||||||||||||||||||||||||||
std::vector<const double*> draws, std::vector<size_t> sizes) { | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rename these to be a little more clear.
Suggested change
|
||||||||||||||||||||||||||||||||||
int num_chains = sizes.size(); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
or use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is only the first size checked to see if it is zero and then we return a NaN? If one chain failed we should still be able to use information from all of the other chains. Looking at the rest of the code, unless there is a math reason to not ignore zero sized chains I think we should just prune them std::vector<const double*> nonzero_chains_begins;
std::vector<std::size_t> nonzero_chain_sizes;
for (int i = 0; i < chain_sizes.size(); ++i) {
if (!chain_sizes[i]) {
nonzero_chains_begin.push_back(chain_begins[i]);
nonzero_chains_sizes.push_back(chain_sizes[i]);
}
}
if (!nonzero_chains_sizes.size()) {
return std::numeric_limits<double>::quiet_NaN();
} |
||||||||||||||||||||||||||||||||||
size_t num_draws = sizes[0]; | ||||||||||||||||||||||||||||||||||
if (num_draws == 0) { | ||||||||||||||||||||||||||||||||||
return std::numeric_limits<double>::quiet_NaN(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also should |
||||||||||||||||||||||||||||||||||
for (int chain = 1; chain < num_chains; ++chain) { | ||||||||||||||||||||||||||||||||||
num_draws = std::min(num_draws, sizes[chain]); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
// check if chains are constant; all equal to first draw's value | ||||||||||||||||||||||||||||||||||
bool are_all_const = false; | ||||||||||||||||||||||||||||||||||
Eigen::VectorXd init_draw = Eigen::VectorXd::Zero(num_chains); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (int chain = 0; chain < num_chains; chain++) { | ||||||||||||||||||||||||||||||||||
Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, 1>> draw( | ||||||||||||||||||||||||||||||||||
draws[chain], sizes[chain]); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (int n = 0; n < num_draws; n++) { | ||||||||||||||||||||||||||||||||||
if (!std::isfinite(draw(n))) { | ||||||||||||||||||||||||||||||||||
return std::numeric_limits<double>::quiet_NaN(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
aleksgorica marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
init_draw(chain) = draw(0); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if (draw.isApproxToConstant(draw(0))) { | ||||||||||||||||||||||||||||||||||
are_all_const |= true; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if (are_all_const) { | ||||||||||||||||||||||||||||||||||
// If all chains are constant then return NaN | ||||||||||||||||||||||||||||||||||
// if they all equal the same constant value | ||||||||||||||||||||||||||||||||||
if (init_draw.isApproxToConstant(init_draw(0))) { | ||||||||||||||||||||||||||||||||||
return std::numeric_limits<double>::quiet_NaN(); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But it is fine if each chain is constant, but each one is a different value? tbc I'm asking because idk if that is how the paper is written or not. I suppose this makes sense in the case of many short chains There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you are correct. The current implementation fails if different chains are constant. For example, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doing something intelligent around constant chains would be a big improvement on our current NaN behavior. But I'm not sure what that is as there's not a number that makes sense as the ESS. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If all chains have the same constant value, we can't make the difference between all chains being stuck or variable actually being constant (e.g. diagonal of correlation matrix) as Stan doesn't tag the variables. In that case diagnostics in R return NA. If the chains have different constant values, then the variable can't be a true constant, and Rhat Inf is fine. |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Eigen::MatrixXd matrix(num_draws, num_chains); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (int col = 0; col < num_chains; ++col) { | ||||||||||||||||||||||||||||||||||
for (int row = 0; row < num_draws; ++row) { | ||||||||||||||||||||||||||||||||||
matrix(row, col) = draws[col][row]; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See other comment about moving this up to the other loop
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
double rhat_bulk = rhat(rank_transform(matrix)); | ||||||||||||||||||||||||||||||||||
double rhat_tail = rhat(rank_transform( | ||||||||||||||||||||||||||||||||||
(matrix.array() - math::quantile(matrix.reshaped(), 0.5)).abs())); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
return std::max(rhat_bulk, rhat_tail); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question for @avehtari Do we want to just return the max or should we return a pair so the user can see the bulk and tail rhats? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is useful for the user to know both. A diagnostic message could be simplified by reporting if the max of these is too low, but otherwise I would prefer that both would be available for the user. Making them both available does change the io via csv and changing csv structures need to be considered carefully There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay then @aleksgorica can you have this return back an |
||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the potential scale reduction (Rhat) for the specified | ||||||||||||||||||||||||||||||||||
* parameter across all kept samples. | ||||||||||||||||||||||||||||||||||
|
@@ -31,6 +176,7 @@ namespace analyze { | |||||||||||||||||||||||||||||||||
* @param sizes stores sizes of chains | ||||||||||||||||||||||||||||||||||
* @return potential scale reduction for the specified parameter | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
inline double compute_potential_scale_reduction( | ||||||||||||||||||||||||||||||||||
SteveBronder marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
std::vector<const double*> draws, std::vector<size_t> sizes) { | ||||||||||||||||||||||||||||||||||
int num_chains = sizes.size(); | ||||||||||||||||||||||||||||||||||
|
@@ -71,34 +217,39 @@ inline double compute_potential_scale_reduction( | |||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
using boost::accumulators::accumulator_set; | ||||||||||||||||||||||||||||||||||
using boost::accumulators::stats; | ||||||||||||||||||||||||||||||||||
using boost::accumulators::tag::mean; | ||||||||||||||||||||||||||||||||||
using boost::accumulators::tag::variance; | ||||||||||||||||||||||||||||||||||
Eigen::MatrixXd matrix(num_draws, num_chains); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Eigen::VectorXd chain_mean(num_chains); | ||||||||||||||||||||||||||||||||||
accumulator_set<double, stats<variance>> acc_chain_mean; | ||||||||||||||||||||||||||||||||||
Eigen::VectorXd chain_var(num_chains); | ||||||||||||||||||||||||||||||||||
double unbiased_var_scale = num_draws / (num_draws - 1.0); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (int chain = 0; chain < num_chains; ++chain) { | ||||||||||||||||||||||||||||||||||
accumulator_set<double, stats<mean, variance>> acc_draw; | ||||||||||||||||||||||||||||||||||
for (int n = 0; n < num_draws; ++n) { | ||||||||||||||||||||||||||||||||||
acc_draw(draws[chain][n]); | ||||||||||||||||||||||||||||||||||
for (int col = 0; col < num_chains; ++col) { | ||||||||||||||||||||||||||||||||||
for (int row = 0; row < num_draws; ++row) { | ||||||||||||||||||||||||||||||||||
matrix(row, col) = draws[col][row]; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
chain_mean(chain) = boost::accumulators::mean(acc_draw); | ||||||||||||||||||||||||||||||||||
acc_chain_mean(chain_mean(chain)); | ||||||||||||||||||||||||||||||||||
chain_var(chain) | ||||||||||||||||||||||||||||||||||
= boost::accumulators::variance(acc_draw) * unbiased_var_scale; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
double var_between = num_draws * boost::accumulators::variance(acc_chain_mean) | ||||||||||||||||||||||||||||||||||
* num_chains / (num_chains - 1); | ||||||||||||||||||||||||||||||||||
double var_within = chain_var.mean(); | ||||||||||||||||||||||||||||||||||
return rhat(matrix); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
// rewrote [(n-1)*W/n + B/n]/W as (n-1+ B/W)/n | ||||||||||||||||||||||||||||||||||
return sqrt((var_between / var_within + num_draws - 1) / num_draws); | ||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the potential scale reduction (Rhat) using rank based diagnostic for | ||||||||||||||||||||||||||||||||||
* the specified parameter across all kept samples. Based on paper | ||||||||||||||||||||||||||||||||||
* https://arxiv.org/abs/1903.08008 | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* See more details in Stan reference manual section "Potential | ||||||||||||||||||||||||||||||||||
* Scale Reduction". http://mc-stan.org/users/documentation | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* Current implementation assumes draws are stored in contiguous | ||||||||||||||||||||||||||||||||||
* blocks of memory. Chains are trimmed from the back to match the | ||||||||||||||||||||||||||||||||||
* length of the shortest chain. Argument size will be broadcast to | ||||||||||||||||||||||||||||||||||
* same length as draws. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* @param draws stores pointers to arrays of chains | ||||||||||||||||||||||||||||||||||
* @param size stores sizes of chains | ||||||||||||||||||||||||||||||||||
* @return potential scale reduction for the specified parameter | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
inline double compute_potential_scale_reduction_rank( | ||||||||||||||||||||||||||||||||||
std::vector<const double*> draws, size_t size) { | ||||||||||||||||||||||||||||||||||
int num_chains = draws.size(); | ||||||||||||||||||||||||||||||||||
std::vector<size_t> sizes(num_chains, size); | ||||||||||||||||||||||||||||||||||
return compute_potential_scale_reduction_rank(draws, sizes); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
|
@@ -124,6 +275,40 @@ inline double compute_potential_scale_reduction( | |||||||||||||||||||||||||||||||||
return compute_potential_scale_reduction(draws, sizes); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the potential scale reduction (Rhat) using rank based diagnostic for | ||||||||||||||||||||||||||||||||||
* the specified parameter across all kept samples. Based on paper | ||||||||||||||||||||||||||||||||||
* https://arxiv.org/abs/1903.08008 | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* When the number of total draws N is odd, the (N+1)/2th draw is ignored. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* See more details in Stan reference manual section "Potential | ||||||||||||||||||||||||||||||||||
* Scale Reduction". http://mc-stan.org/users/documentation | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* Current implementation assumes draws are stored in contiguous | ||||||||||||||||||||||||||||||||||
* blocks of memory. Chains are trimmed from the back to match the | ||||||||||||||||||||||||||||||||||
* length of the shortest chain. | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You use |
||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* @param draws stores pointers to arrays of chains | ||||||||||||||||||||||||||||||||||
* @param sizes stores sizes of chains | ||||||||||||||||||||||||||||||||||
* @return potential scale reduction for the specified parameter | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
inline double compute_split_potential_scale_reduction_rank( | ||||||||||||||||||||||||||||||||||
std::vector<const double*> draws, std::vector<size_t> sizes) { | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We want these arguments to come in as constant references. As written this will make a hard copy of the input vectors when you call this function. Making the arguments references (
Suggested change
We want containers (like This comment applies to all the function signatures you added here |
||||||||||||||||||||||||||||||||||
int num_chains = sizes.size(); | ||||||||||||||||||||||||||||||||||
size_t num_draws = sizes[0]; | ||||||||||||||||||||||||||||||||||
for (int chain = 1; chain < num_chains; ++chain) { | ||||||||||||||||||||||||||||||||||
num_draws = std::min(num_draws, sizes[chain]); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
std::vector<const double*> split_draws = split_chains(draws, sizes); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
double half = num_draws / 2.0; | ||||||||||||||||||||||||||||||||||
std::vector<size_t> half_sizes(2 * num_chains, std::floor(half)); | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to make it more clear you are using floating point division and then taking the floor to get the index
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
return compute_potential_scale_reduction_rank(split_draws, half_sizes); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the split potential scale reduction (Rhat) for the | ||||||||||||||||||||||||||||||||||
* specified parameter across all kept samples. When the number of | ||||||||||||||||||||||||||||||||||
|
@@ -156,6 +341,32 @@ inline double compute_split_potential_scale_reduction( | |||||||||||||||||||||||||||||||||
return compute_potential_scale_reduction(split_draws, half_sizes); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the potential scale reduction (Rhat) using rank based diagnostic for | ||||||||||||||||||||||||||||||||||
* the specified parameter across all kept samples. Based on paper | ||||||||||||||||||||||||||||||||||
* https://arxiv.org/abs/1903.08008 | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* When the number of total draws N is odd, the (N+1)/2th draw is ignored. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* See more details in Stan reference manual section "Potential | ||||||||||||||||||||||||||||||||||
* Scale Reduction". http://mc-stan.org/users/documentation | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* Current implementation assumes draws are stored in contiguous | ||||||||||||||||||||||||||||||||||
* blocks of memory. Chains are trimmed from the back to match the | ||||||||||||||||||||||||||||||||||
* length of the shortest chain. Argument size will be broadcast to | ||||||||||||||||||||||||||||||||||
* same length as draws. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* @param draws stores pointers to arrays of chains | ||||||||||||||||||||||||||||||||||
* @param size stores sizes of chains | ||||||||||||||||||||||||||||||||||
* @return potential scale reduction for the specified parameter | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
inline double compute_split_potential_scale_reduction_rank( | ||||||||||||||||||||||||||||||||||
std::vector<const double*> draws, size_t size) { | ||||||||||||||||||||||||||||||||||
int num_chains = draws.size(); | ||||||||||||||||||||||||||||||||||
std::vector<size_t> sizes(num_chains, size); | ||||||||||||||||||||||||||||||||||
return compute_split_potential_scale_reduction_rank(draws, sizes); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||
* Computes the split potential scale reduction (Rhat) for the | ||||||||||||||||||||||||||||||||||
* specified parameter across all kept samples. When the number of | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.