Skip to content

Commit

Permalink
coll/accelerator: duplicate reduce code for reduce_local
Browse files Browse the repository at this point in the history
Signed-off-by: Akshay Venkatesh <[email protected]>
  • Loading branch information
Akshay-Venkatesh committed Aug 27, 2024
1 parent 57485f4 commit 7f6f788
Showing 1 changed file with 54 additions and 13 deletions.
67 changes: 54 additions & 13 deletions ompi/mca/coll/accelerator/coll_accelerator_reduce.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2004-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
Expand Down Expand Up @@ -36,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
mca_coll_base_module_t *module)
{
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
int rank = (comm == NULL) ? -1 : ompi_comm_rank(comm);
int rank = ompi_comm_rank(comm);
ptrdiff_t gap;
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
size_t bufsize;
Expand Down Expand Up @@ -71,15 +71,9 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
rbuf2 = rbuf; /* save away original buffer */
rbuf = rbuf1 - gap;
}

if ((comm == NULL) && (root == -1)) {
ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
rc = OMPI_SUCCESS;
} else {
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
dtype, op, root, comm,
s->c_coll.coll_reduce_module);
}
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
dtype, op, root, comm,
s->c_coll.coll_reduce_module);

if (NULL != sbuf1) {
free(sbuf1);
Expand All @@ -98,6 +92,53 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
struct ompi_op_t *op,
mca_coll_base_module_t *module)
{
return mca_coll_accelerator_reduce(sbuf, rbuf, count, dtype, op, -1, NULL,
module);
ptrdiff_t gap;
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
size_t bufsize;
int rc;

bufsize = opal_datatype_span(&dtype->super, count, &gap);

rc = mca_coll_accelerator_check_buf((void *)sbuf);
if (rc < 0) {
return rc;
}

if ((MPI_IN_PLACE != sbuf) && (rc > 0)) {
sbuf1 = (char*)malloc(bufsize);
if (NULL == sbuf1) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
mca_coll_accelerator_memcpy(sbuf1, sbuf, bufsize);
sbuf = sbuf1 - gap;
}

rc = mca_coll_accelerator_check_buf(rbuf);
if (rc < 0) {
return rc;
}

if (rc > 0) {
rbuf1 = (char*)malloc(bufsize);
if (NULL == rbuf1) {
if (NULL != sbuf1) free(sbuf1);
return OMPI_ERR_OUT_OF_RESOURCE;
}
mca_coll_accelerator_memcpy(rbuf1, rbuf, bufsize);
rbuf2 = rbuf; /* save away original buffer */
rbuf = rbuf1 - gap;
}

ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
rc = OMPI_SUCCESS;

if (NULL != sbuf1) {
free(sbuf1);
}
if (NULL != rbuf1) {
rbuf = rbuf2;
mca_coll_accelerator_memcpy(rbuf, rbuf1, bufsize);
free(rbuf1);
}
return rc;
}

0 comments on commit 7f6f788

Please sign in to comment.