Skip to content

Commit

Permalink
add code for bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
jianxiaoyang committed Jul 25, 2023
1 parent d20b851 commit 8d52884
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 8 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export(listOpenCLDevices)
export(meanLinearPredictor)
export(mse)
export(readCyclopsData)
export(runBootstrap)
export(setOpenCLDevice)
export(simulateCyclopsData)
export(splitTime)
Expand Down
15 changes: 15 additions & 0 deletions R/ModelFit.R
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,21 @@ getSEs <- function(object, covariates) {
ses
}


#' @title Run Bootstrap for Cyclops model parameter
#'
#' @description
#' \code{runBootstrap} checks the data interface attached to a fitted Cyclops
#' model and then hands the bootstrap run over to the compiled Cyclops routine.
#'
#' @param object A fitted Cyclops model object
#' @param outFileName Character: Output file name
#' @param treatmentId Character: variable to output
#' @param replicates Numeric: number of bootstrap samples
#'
#' @return The list produced by the compiled bootstrap routine.
#'
#' @export
runBootstrap <- function(object, outFileName, treatmentId, replicates) {
    # Check the interface before touching the external pointer.
    .checkInterface(object$cyclopsData, testOnly = TRUE)

    # The value of the last expression is the function's (visible) result.
    .cyclopsRunBootstrap(
        inRcppCcdInterface = object$cyclopsData$cyclopsInterfacePtr,
        outFileName = outFileName,
        treatmentId = treatmentId,
        replicates = replicates
    )
}

#' @title Confidence intervals for Cyclops model parameters
#'
#' @description
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@
.Call(`_Cyclops_cyclopsFitModel`, inRcppCcdInterface)
}

# Thin R-side stub: passes its four arguments, unchanged and in order, to the
# compiled routine `_Cyclops_cyclopsRunBootstrap` via .Call and returns its result.
# NOTE(review): this looks like generated Rcpp glue (R/RcppExports.R); hand-written
# comments here may be overwritten when the exports are regenerated - confirm.
.cyclopsRunBootstrap <- function(inRcppCcdInterface, outFileName, treatmentId, replicates) {
.Call(`_Cyclops_cyclopsRunBootstrap`, inRcppCcdInterface, outFileName, treatmentId, replicates)
}

# Thin R-side stub: passes the interface pointer to the compiled routine
# `_Cyclops_cyclopsLogModel` via .Call and returns its result.
.cyclopsLogModel <- function(inRcppCcdInterface) {
.Call(`_Cyclops_cyclopsLogModel`, inRcppCcdInterface)
}
Expand Down
20 changes: 20 additions & 0 deletions man/runBootstrap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions src/RcppCyclopsInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,32 @@ List cyclopsFitModel(SEXP inRcppCcdInterface) {
return list;
}

// Runs the bootstrap on an existing Cyclops interface and hands the outcome back to R.
//
//   inRcppCcdInterface  external pointer wrapping an RcppCcdInterface
//   outFileName         copied into the interface arguments (outFileName)
//   treatmentId         forwarded unchanged to RcppCcdInterface::runBoostrap
//   replicates          copied into the interface arguments (replicates)
//
// The returned list carries the interface pointer ("interface"), the value returned
// by runBoostrap ("timeFit"), and the entries appended from interface->getResult().
// [[Rcpp::export(".cyclopsRunBootstrap")]]
List cyclopsRunBootstrap(SEXP inRcppCcdInterface, const std::string& outFileName, std::string& treatmentId, int replicates) {
    using namespace bsccs;

    XPtr<RcppCcdInterface> interface(inRcppCcdInterface);

    // Switch the interface's argument block into bootstrap mode.
    auto& args = interface->getArguments();
    args.doBootstrap = true;
    args.outFileName = outFileName;
    args.replicates = replicates;

    // Snapshot the current coefficients before the bootstrap runs.
    // TODO Move this snapshot into interface.runBootstrap
    std::vector<double> pointEstimates;
    for (int idx = 0; idx < interface->getCcd().getBetaSize(); ++idx) {
        pointEstimates.push_back(interface->getCcd().getBeta(idx));
    }

    const double elapsed = interface->runBoostrap(pointEstimates, treatmentId);

    interface->diagnoseModel(0.0, 0.0);

    List result = List::create(
        Rcpp::Named("interface") = interface,
        Rcpp::Named("timeFit") = elapsed
    );
    RcppCcdInterface::appendRList(result, interface->getResult());
    return result;
}

// [[Rcpp::export(".cyclopsLogModel")]]
List cyclopsLogModel(SEXP inRcppCcdInterface) {
using namespace bsccs;
Expand Down
4 changes: 2 additions & 2 deletions src/RcppCyclopsInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class RcppCcdInterface : public CcdInterface {
return CcdInterface::runCrossValidation(ccd, modelData);
}

double runBoostrap(std::vector<double>& savedBeta) {
return CcdInterface::runBoostrap(ccd, modelData, savedBeta);
// Convenience overload for the R interface: forwards to CcdInterface::runBoostrap,
// supplying the `ccd` and `modelData` in scope here together with the caller's
// savedBeta and treatmentId, and returns whatever that call returns.
double runBoostrap(std::vector<double>& savedBeta, std::string& treatmentId) {
return CcdInterface::runBoostrap(ccd, modelData, savedBeta, treatmentId);
}

void logResultsToFile(const std::string& fileName, bool withASE);
Expand Down
15 changes: 15 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// cyclopsRunBootstrap
// C entry point registered for .Call: converts each incoming SEXP to the matching
// C++ parameter type, invokes cyclopsRunBootstrap, and returns the wrapped result.
// NOTE(review): this looks like generated Rcpp glue (src/RcppExports.cpp); comments
// added by hand may be overwritten when the exports are regenerated - confirm.
List cyclopsRunBootstrap(SEXP inRcppCcdInterface, const std::string& outFileName, std::string& treatmentId, int replicates);
RcppExport SEXP _Cyclops_cyclopsRunBootstrap(SEXP inRcppCcdInterfaceSEXP, SEXP outFileNameSEXP, SEXP treatmentIdSEXP, SEXP replicatesSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
// One input_parameter per argument, in the same order as the C++ signature.
Rcpp::traits::input_parameter< SEXP >::type inRcppCcdInterface(inRcppCcdInterfaceSEXP);
Rcpp::traits::input_parameter< const std::string& >::type outFileName(outFileNameSEXP);
Rcpp::traits::input_parameter< std::string& >::type treatmentId(treatmentIdSEXP);
Rcpp::traits::input_parameter< int >::type replicates(replicatesSEXP);
rcpp_result_gen = Rcpp::wrap(cyclopsRunBootstrap(inRcppCcdInterface, outFileName, treatmentId, replicates));
return rcpp_result_gen;
END_RCPP
}
// cyclopsLogModel
List cyclopsLogModel(SEXP inRcppCcdInterface);
RcppExport SEXP _Cyclops_cyclopsLogModel(SEXP inRcppCcdInterfaceSEXP) {
Expand Down Expand Up @@ -820,6 +834,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_Cyclops_cyclopsSetControl", (DL_FUNC) &_Cyclops_cyclopsSetControl, 23},
{"_Cyclops_cyclopsRunCrossValidationl", (DL_FUNC) &_Cyclops_cyclopsRunCrossValidationl, 1},
{"_Cyclops_cyclopsFitModel", (DL_FUNC) &_Cyclops_cyclopsFitModel, 1},
{"_Cyclops_cyclopsRunBootstrap", (DL_FUNC) &_Cyclops_cyclopsRunBootstrap, 4},
{"_Cyclops_cyclopsLogModel", (DL_FUNC) &_Cyclops_cyclopsLogModel, 1},
{"_Cyclops_cyclopsInitializeModel", (DL_FUNC) &_Cyclops_cyclopsInitializeModel, 4},
{"_Cyclops_listOpenCLDevices", (DL_FUNC) &_Cyclops_listOpenCLDevices, 0},
Expand Down
14 changes: 11 additions & 3 deletions src/cyclops/CcdInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,21 +697,29 @@ SelectorType CcdInterface::getDefaultSelectorTypeOrOverride(SelectorType selecto
// Runs the bootstrap: picks a selector type, builds a BootstrapSelector and a
// BootstrapDriver for arguments.replicates replicates, drives them against *ccd,
// writes the summary via driver.logHR, and returns the seconds elapsed between the
// two gettimeofday calls (result logging is outside the timed span).
double CcdInterface::runBoostrap(
CyclicCoordinateDescent *ccd,
AbstractModelData *modelData,
std::vector<double>& savedBeta) {
std::vector<double>& savedBeta,
std::string& treatmentId) {
struct timeval time1, time2;
gettimeofday(&time1, NULL);

auto selectorType = getDefaultSelectorTypeOrOverride(
arguments.crossValidation.selectorType, modelData->getModelType());

BootstrapSelector selector(arguments.replicates, modelData->getPidVectorSTL(),
// For BY_ROW the selector is given the row indices 0..numberOfRows-1 as ids
// instead of the pid vector.
vector<int> ids;
if (selectorType == SelectorType::BY_ROW) {
// NOTE(review): this writes straight to std::cout although a `logger` is passed
// to the selector and driver below - consider routing the message through it.
std::cout << "runBoostrap SelectorType::BY_ROW \n";
ids.resize(modelData->getNumberOfRows());
std::iota(ids.begin(), ids.end(), 0);
}
BootstrapSelector selector(arguments.replicates, selectorType == SelectorType::BY_ROW ? ids : modelData->getPidVectorSTL(),
selectorType, arguments.seed, logger, error);
BootstrapDriver driver(arguments.replicates, modelData, logger, error);

driver.drive(*ccd, selector, arguments);
gettimeofday(&time2, NULL);

driver.logResults(arguments, savedBeta, ccd->getConditionId());
// driver.logResults(arguments, savedBeta, ccd->getConditionId());
// Summary output goes through logHR, keyed on treatmentId.
driver.logHR(arguments, savedBeta, treatmentId);
return calculateSeconds(time1, time2);
}

Expand Down
3 changes: 2 additions & 1 deletion src/cyclops/CcdInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ class CcdInterface {
double runBoostrap(
CyclicCoordinateDescent *ccd,
AbstractModelData *modelData,
std::vector<double>& savedBeta);
std::vector<double>& savedBeta,
std::string& treatmentId);

void setDefaultArguments();

Expand Down
56 changes: 56 additions & 0 deletions src/cyclops/drivers/BootstrapDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,60 @@ void BootstrapDriver::logResults(const CCDArguments& arguments, std::vector<doub
outLog.close();
}

void BootstrapDriver::logHR(const CCDArguments& arguments, std::vector<double>& savedBeta, std::string treatmentId) {

// int j = J-1;
int tId = 0;
while (modelData->getColumnLabel(tId) != treatmentId) tId++;

ofstream outLog(arguments.outFileName.c_str());
if (!outLog) {
std::ostringstream stream;
stream << "Unable to open log file: " << arguments.bsFileName;
error->throwError(stream);
}

string sep(","); // TODO Make option

// Raw estimates
// for(rvector::iterator it = estimates[j]->begin(); it != estimates[j]->end(); ++it) outLog << *it << endl;

// Stats
outLog << "Drug_concept_id" << sep << "Condition_concept_id" << sep <<
"score" << sep << "standard_error" << sep << "bs_mean" << sep << "bs_lower" << sep <<
"bs_upper" << sep << "bs_prob0" << endl;

for (int j = tId; j < J; ++j) {
outLog << modelData->getColumnLabel(j) <<
sep << treatmentId << sep;

double mean = 0.0;
double var = 0.0;
double prob0 = 0.0;
for (rvector::iterator it = estimates[j]->begin(); it != estimates[j]->end(); ++it) {
mean += *it;
var += *it * *it;
if (*it == 0.0) {
prob0 += 1.0;
}
}

double size = static_cast<double>(estimates[j]->size());
mean /= size;
var = (var / size) - (mean * mean);
prob0 /= size;

sort(estimates[j]->begin(), estimates[j]->end());
int offsetLower = static_cast<int>(size * 0.025);
int offsetUpper = static_cast<int>(size * 0.975);

double lower = *(estimates[j]->begin() + offsetLower);
double upper = *(estimates[j]->begin() + offsetUpper);

outLog << savedBeta[j] << sep;
outLog << std::sqrt(var) << sep << mean << sep << lower << sep << upper << sep << prob0 << endl;
}
outLog.close();
}

} // namespace
2 changes: 2 additions & 0 deletions src/cyclops/drivers/BootstrapDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class BootstrapDriver : public AbstractDriver {

void logResults(const CCDArguments& arguments, std::vector<double>& savedBeta, std::string conditionId);

// Writes per-column bootstrap summary statistics as CSV to arguments.outFileName,
// starting at the column whose label equals treatmentId; savedBeta supplies the
// point estimate printed on each row.
void logHR(const CCDArguments& arguments, std::vector<double>& savedBeta, std::string treatmentId);

private:
const int replicates;
AbstractModelData* modelData;
Expand Down
8 changes: 6 additions & 2 deletions src/cyclops/drivers/BootstrapSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,20 @@ void BootstrapSelector::permute() {

// Get non-excluded indices
int N_new = indicesIncluded.size();
if (type == SelectorType::BY_PID) {
// if (type == SelectorType::BY_PID) {
std::uniform_int_distribution<int> uniform(0, N_new - 1);
for (int i = 0; i < N_new; i++) {
int ind = uniform(prng);
int draw = indicesIncluded[ind];
selectedSet.insert(draw);
}
/*
} else {
std::ostringstream stream;
stream << "BootstrapSelector::permute is not yet implemented.";
error->throwError(stream);
}
*/
}

void BootstrapSelector::getWeights(int batch, std::vector<double>& weights) {
Expand All @@ -80,16 +82,18 @@ void BootstrapSelector::getWeights(int batch, std::vector<double>& weights) {
return;
}

if (type == SelectorType::BY_PID) {
// if (type == SelectorType::BY_PID) {
for (size_t k = 0; k < K; k++) {
int count = selectedSet.count(ids.at(k));
weights[k] = static_cast<real>(count);
}
/*
} else {
std::ostringstream stream;
stream << "BootstrapSelector::getWeights is not yet implemented.";
error->throwError(stream);
}
*/
}

void BootstrapSelector::getComplement(std::vector<double>& weights) {
Expand Down

0 comments on commit 8d52884

Please sign in to comment.