[DRAFT][CELEBORN-1490][CIP-6] Support Flink hybrid shuffle integration with Apache Celeborn #2678

Closed · wants to merge 16 commits
FlinkResultPartitionInfo.java
@@ -18,47 +18,60 @@
 package org.apache.celeborn.plugin.flink;

 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.shuffle.PartitionDescriptor;
 import org.apache.flink.runtime.shuffle.ProducerDescriptor;

 import org.apache.celeborn.plugin.flink.utils.FlinkUtils;

 public class FlinkResultPartitionInfo {
   private final JobID jobID;
-  private final PartitionDescriptor partitionDescriptor;
-  private final ProducerDescriptor producerDescriptor;
+  private final ResultPartitionID resultPartitionId;
+  private final IntermediateResultPartitionID partitionId;
+  private final ExecutionAttemptID producerId;
+
+  public FlinkResultPartitionInfo(JobID jobID, ResultPartitionID resultPartitionId) {
+    this.jobID = jobID;
+    this.resultPartitionId = resultPartitionId;
+    this.partitionId = resultPartitionId.getPartitionId();
+    this.producerId = resultPartitionId.getProducerId();
+  }

   public FlinkResultPartitionInfo(
       JobID jobID, PartitionDescriptor partitionDescriptor, ProducerDescriptor producerDescriptor) {
     this.jobID = jobID;
-    this.partitionDescriptor = partitionDescriptor;
-    this.producerDescriptor = producerDescriptor;
+    this.resultPartitionId =
+        new ResultPartitionID(
+            partitionDescriptor.getPartitionId(), producerDescriptor.getProducerExecutionId());
+    this.partitionId = partitionDescriptor.getPartitionId();
+    this.producerId = producerDescriptor.getProducerExecutionId();
   }

   public ResultPartitionID getResultPartitionId() {
-    return new ResultPartitionID(
-        partitionDescriptor.getPartitionId(), producerDescriptor.getProducerExecutionId());
+    return resultPartitionId;
   }

   public String getShuffleId() {
-    return FlinkUtils.toShuffleId(jobID, partitionDescriptor.getResultId());
+    return FlinkUtils.toShuffleId(jobID, partitionId.getIntermediateDataSetID());
   }

   public int getTaskId() {
-    return partitionDescriptor.getPartitionId().getPartitionNumber();
+    return partitionId.getPartitionNumber();
   }

   public String getAttemptId() {
-    return FlinkUtils.toAttemptId(producerDescriptor.getProducerExecutionId());
+    return FlinkUtils.toAttemptId(producerId);
   }

   @Override
   public String toString() {
     final StringBuilder sb = new StringBuilder("FlinkResultPartitionInfo{");
     sb.append("jobID=").append(jobID);
-    sb.append(", partitionDescriptor=").append(partitionDescriptor.getPartitionId());
-    sb.append(", producerDescriptor=").append(producerDescriptor.getProducerExecutionId());
+    sb.append(", resultPartitionId=").append(resultPartitionId);
+    sb.append(", partitionDescriptor=").append(partitionId);
+    sb.append(", producerDescriptor=").append(producerId);
     sb.append('}');
     return sb.toString();
   }
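Note on the two constructors above: both are expected to converge on the same identifiers, since the descriptor-based path builds its ResultPartitionID from the very partition and producer IDs that the new ResultPartitionID-based path consumes directly. A minimal sketch of that invariant (assuming jobID, partitionDescriptor, and producerDescriptor come from Flink's ShuffleMaster partition registration; not part of the PR):

    // Descriptor-based path, used when a partition is registered.
    FlinkResultPartitionInfo fromDescriptors =
        new FlinkResultPartitionInfo(jobID, partitionDescriptor, producerDescriptor);
    // New path for callers that only hold a ResultPartitionID.
    FlinkResultPartitionInfo fromPartitionId =
        new FlinkResultPartitionInfo(jobID, fromDescriptors.getResultPartitionId());
    // Both derive partitionId/producerId from the same ResultPartitionID.
    assert fromDescriptors.getShuffleId().equals(fromPartitionId.getShuffleId());
    assert fromDescriptors.getTaskId() == fromPartitionId.getTaskId();
    assert fromDescriptors.getAttemptId().equals(fromPartitionId.getAttemptId());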
RemoteBufferStreamReader.java
@@ -85,9 +85,9 @@ public void open(int initialCredit) {
     try {
       bufferStream =
           client.readBufferedPartition(
-              shuffleId, partitionId, subPartitionIndexStart, subPartitionIndexEnd);
+              shuffleId, partitionId, subPartitionIndexStart, subPartitionIndexEnd, false);
       bufferStream.open(
-          RemoteBufferStreamReader.this::requestBuffer, initialCredit, messageConsumer);
+          RemoteBufferStreamReader.this::requestBuffer, initialCredit, messageConsumer, null);
     } catch (Exception e) {
       logger.warn("Failed to open stream and report to flink framework. ", e);
       messageConsumer.accept(new TransportableError(0L, e));
@@ -158,6 +158,6 @@ public void dataReceived(ReadData readData) {
   public void onStreamEnd(BufferStreamEnd streamEnd) {
     long streamId = streamEnd.getStreamId();
     logger.debug("Buffer stream reader get stream end for {}", streamId);
-    bufferStream.moveToNextPartitionIfPossible(streamId);
+    bufferStream.moveToNextPartitionIfPossible(streamId, null, (stream, subPartitionId) -> {});
   }
 }
RemoteShuffleOutputGate.java (file name inferred from the methods shown)
@@ -33,6 +33,7 @@
 import org.apache.celeborn.common.exception.DriverChangedException;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
 import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
 import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
 import org.apache.celeborn.plugin.flink.utils.BufferUtils;
@@ -207,13 +208,13 @@ FlinkShuffleClientImpl getShuffleClient() {
   }

   /** Writes a piece of data to a subpartition. */
-  public void write(ByteBuf byteBuf, int subIdx) {
+  public void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
     try {
       flinkShuffleClient.pushDataToLocation(
           shuffleId,
           mapId,
           attemptId,
-          subIdx,
+          bufferHeader.getSubPartitionId(),
           io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
           partitionLocation,
           () -> byteBuf.release());
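This signature change pairs with the BufferPacker change below: the packer now hands its ripe-buffer handler a full BufferHeader instead of a bare sub-partition index, so the gate's write method can be wired in as the handler directly. A sketch (assuming outputGate is an instance of this class):

    // The handler type BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException>
    // now matches write(ByteBuf, BufferHeader) exactly.
    BufferPacker packer = new BufferPacker(outputGate::write);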
BufferHeader.java
@@ -54,6 +54,14 @@ public BufferHeader(
     this.size = size;
   }

+  public void setSubPartitionId(int subPartitionId) {
+    this.subPartitionId = subPartitionId;
+  }
+
+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
   public Buffer.DataType getDataType() {
     return dataType;
   }
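The new accessors let the sub-partition id ride along inside the header rather than travel as a separate argument. A sketch mirroring what BufferPacker.handleRipeBuffer does below (dataType, isCompressed, size, and subPartitionId are assumed in scope):

    // Build a header, then attach the sub-partition id before handing it off.
    BufferHeader header = new BufferHeader(dataType, isCompressed, size);
    header.setSubPartitionId(subPartitionId);
    int target = header.getSubPartitionId(); // consumed on the push path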
BufferPacker.java
@@ -26,6 +26,8 @@
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -41,14 +43,15 @@ public interface BiConsumerWithException<T, U, E extends Throwable> {
     void accept(T var1, U var2) throws E;
   }

-  private final BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler;
+  private final BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException>
+      ripeBufferHandler;

   private Buffer cachedBuffer;

   private int currentSubIdx = -1;

   public BufferPacker(
-      BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler) {
+      BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException> ripeBufferHandler) {
     this.ripeBufferHandler = ripeBufferHandler;
   }
@@ -71,7 +74,8 @@ public void process(Buffer buffer, int subIdx) throws InterruptedException {
       int targetSubIdx = currentSubIdx;
       currentSubIdx = subIdx;
       logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes());
-      handleRipeBuffer(dumpedBuffer, targetSubIdx);
+      handleRipeBuffer(
+          dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed());
     } else {
       /**
        * this is an optimization. if cachedBuffer can contain other buffer, then other buffer can
@@ -95,7 +99,8 @@ public void process(Buffer buffer, int subIdx) throws InterruptedException {
         cachedBuffer = buffer;
         logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes());

-        handleRipeBuffer(dumpedBuffer, currentSubIdx);
+        handleRipeBuffer(
+            dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed());
       }
     }
   }
@@ -109,18 +114,34 @@ private void logBufferPack(boolean isDrain, Buffer.DataType dataType, int length
         length);
   }

-  public void drain() throws InterruptedException {
+  public void drain() {
     if (cachedBuffer != null) {
       logBufferPack(true, cachedBuffer.getDataType(), cachedBuffer.readableBytes());
-      handleRipeBuffer(cachedBuffer, currentSubIdx);
+      try {
+        handleRipeBuffer(
+            cachedBuffer, currentSubIdx, cachedBuffer.getDataType(), cachedBuffer.isCompressed());
+      } catch (InterruptedException e) {
+        throw new FlinkRuntimeException(e);
+      }
     }
     cachedBuffer = null;
     currentSubIdx = -1;
   }

-  private void handleRipeBuffer(Buffer buffer, int subIdx) throws InterruptedException {
+  private void handleRipeBuffer(
+      Buffer buffer, int subIdx, Buffer.DataType dataType, boolean isCompressed)
+      throws InterruptedException {
+    if (buffer == null || buffer.readableBytes() == 0) {
+      return;
+    }
     buffer.setCompressed(false);
-    ripeBufferHandler.accept(buffer.asByteBuf(), subIdx);
+    BufferHeader bufferHeader = new BufferHeader(dataType, isCompressed, buffer.getSize());
+    bufferHeader.setSubPartitionId(subIdx);
+    ripeBufferHandler.accept(buffer.asByteBuf(), bufferHeader);
   }

+  public boolean isEmpty() {
+    return cachedBuffer == null || cachedBuffer.readableBytes() == 0;
+  }
+
   public void close() {
@@ -134,6 +155,24 @@ public void close() {
   public static Queue<Buffer> unpack(ByteBuf byteBuf) {
     Queue<Buffer> buffers = new ArrayDeque<>();
     try {
+      if (byteBuf instanceof CompositeByteBuf) {
+        CompositeByteBuf compositeByteBuf = (CompositeByteBuf) byteBuf;
+        ByteBuf headerBuffer = compositeByteBuf.component(0).unwrap();
+        ByteBuf dataBuffer = compositeByteBuf.component(1).unwrap();
+        dataBuffer.retain();
+        Utils.checkState(dataBuffer instanceof Buffer, "Illegal buffer type.");
+        BufferHeader bufferHeader = BufferUtils.getBufferHeaderFromByteBuf(headerBuffer, 0);
+        Buffer slice = ((Buffer) dataBuffer).readOnlySlice(0, bufferHeader.getSize());
+        buffers.add(
+            new UnpackSlicedBuffer(
+                slice,
+                bufferHeader.getDataType(),
+                bufferHeader.isCompressed(),
+                bufferHeader.getSize()));
+
+        return buffers;
+      }
+
       Utils.checkState(byteBuf instanceof Buffer, "Illegal buffer type.");

       Buffer buffer = (Buffer) byteBuf;
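The new unpack branch assumes a two-component composite: component 0 carries the serialized BufferHeader, component 1 the data, from which a read-only slice of bufferHeader.getSize() bytes is wrapped into an UnpackSlicedBuffer. Callers stay agnostic to which wire shape arrived; a usage sketch (handle(...) is hypothetical):

    // Both the plain-Buffer path and the composite (header + data) path
    // come back as a queue of Flink Buffers.
    Queue<Buffer> unpacked = BufferPacker.unpack(receivedByteBuf);
    Buffer buffer;
    while ((buffer = unpacked.poll()) != null) {
      handle(buffer.getDataType(), buffer.isCompressed(), buffer);
    }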
FlinkTransportClientFactory.java
@@ -40,13 +40,18 @@ public class FlinkTransportClientFactory extends TransportClientFactory {

   private ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
   private final int fetchMaxRetries;
+  private final int bufferSizeBytes;

   public FlinkTransportClientFactory(
-      TransportContext context, int fetchMaxRetries, List<TransportClientBootstrap> bootstraps) {
+      TransportContext context,
+      int fetchMaxRetries,
+      List<TransportClientBootstrap> bootstraps,
+      int bufferSizeBytes) {
     super(context, bootstraps);
     bufferSuppliers = JavaUtils.newConcurrentHashMap();
     this.fetchMaxRetries = fetchMaxRetries;
     this.pooledAllocator = new UnpooledByteBufAllocator(true);
+    this.bufferSizeBytes = bufferSizeBytes;
   }

   public TransportClient createClientWithRetry(String remoteHost, int remotePort)
@@ -82,7 +87,10 @@ public TransportClient createClientWithRetry(String remoteHost, int remotePort)
   public TransportClient createClient(String remoteHost, int remotePort)
       throws IOException, InterruptedException {
     return createClient(
-        remoteHost, remotePort, -1, new TransportFrameDecoderWithBufferSupplier(bufferSuppliers));
+        remoteHost,
+        remotePort,
+        -1,
+        new TransportFrameDecoderWithBufferSupplier(bufferSuppliers, bufferSizeBytes));
   }

   public void registerSupplier(long streamId, Supplier<ByteBuf> supplier) {
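Construction-side sketch: the extra bufferSizeBytes is threaded into every new TransportFrameDecoderWithBufferSupplier, presumably so the decoder can size the buffers it requests from the registered suppliers (the decoder change itself is not shown in this diff). transportContext, fetchMaxRetries, bootstraps, and bufferSizeBytes are assumed to come from the plugin's transport configuration:

    FlinkTransportClientFactory factory =
        new FlinkTransportClientFactory(
            transportContext, fetchMaxRetries, bootstraps, bufferSizeBytes);
    TransportClient client = factory.createClientWithRetry(remoteHost, remotePort);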
MessageDecoderExt.java
@@ -25,6 +25,7 @@
 import org.apache.celeborn.common.network.protocol.*;
 import org.apache.celeborn.plugin.flink.buffer.FlinkNettyManagedBuffer;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;

 public class MessageDecoderExt {
   public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody) {
@@ -74,6 +75,11 @@ public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody)
         streamId = in.readLong();
         return new ReadData(streamId);

+      case SUBPARTITION_READ_DATA:
+        streamId = in.readLong();
+        int subPartitionId = in.readInt();
+        return new SubPartitionReadData(streamId, subPartitionId);
+
       case BACKLOG_ANNOUNCEMENT:
         streamId = in.readLong();
         int backlog = in.readInt();
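The decode branch fixes the header layout of the new message: a long stream id followed by an int sub-partition id, twelve bytes before any body. A mirrored encoding sketch (hypothetical; the authoritative encoding lives in SubPartitionReadData itself, and allocator is assumed):

    // Mirrors the reads above: in.readLong(), then in.readInt().
    ByteBuf header = allocator.buffer(Long.BYTES + Integer.BYTES);
    header.writeLong(streamId);      // 8 bytes
    header.writeInt(subPartitionId); // 4 bytes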
ReadClientHandler.java
@@ -39,6 +39,7 @@
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;

 public class ReadClientHandler extends BaseMessageHandler {
   private static Logger logger = LoggerFactory.getLogger(ReadClientHandler.class);
@@ -65,6 +66,8 @@ private void processMessageInternal(long streamId, RequestMessage msg) {
     } else {
       if (msg != null && msg instanceof ReadData) {
         ((ReadData) msg).getFlinkBuffer().release();
+      } else if (msg != null && msg instanceof SubPartitionReadData) {
+        ((SubPartitionReadData) msg).getFlinkBuffer().release();
       }

       logger.warn("Unexpected streamId received: {}", streamId);
@@ -83,6 +86,10 @@ public void receive(TransportClient client, RequestMessage msg) {
         ReadData readData = (ReadData) msg;
         processMessageInternal(readData.getStreamId(), readData);
         break;
+      case SUBPARTITION_READ_DATA:
+        SubPartitionReadData subPartitionReadData = (SubPartitionReadData) msg;
+        processMessageInternal(subPartitionReadData.getStreamId(), subPartitionReadData);
+        break;
       case BACKLOG_ANNOUNCEMENT:
         BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
         processMessageInternal(backlogAnnouncement.getStreamId(), backlogAnnouncement);