From 816a268a285c7bc86758bfde710d5b6149d63cfe Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 3 Sep 2024 10:36:33 +0800 Subject: [PATCH] [CELEBORN-1490][CIP-6] Introduce tier producer in celeborn flink client --- .../plugin/flink/RemoteShuffleOutputGate.java | 5 +- .../plugin/flink/buffer/BufferHeader.java | 9 + .../plugin/flink/buffer/BufferPacker.java | 45 +- .../buffer/ReceivedNoHeaderBufferPacker.java | 112 ++++ .../plugin/flink/utils/BufferUtils.java | 20 + .../plugin/flink/BufferPackSuiteJ.java | 192 ++++++- .../flink/tiered/CelebornTierFactory.java | 12 +- .../tiered/CelebornTierProducerAgent.java | 487 ++++++++++++++++++ 8 files changed, 843 insertions(+), 39 deletions(-) create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java create mode 100644 client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java index d17a182a1bc..f695af14d74 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java @@ -33,6 +33,7 @@ import org.apache.celeborn.common.exception.DriverChangedException; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.plugin.flink.buffer.BufferHeader; import org.apache.celeborn.plugin.flink.buffer.BufferPacker; import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl; import org.apache.celeborn.plugin.flink.utils.BufferUtils; @@ -207,13 +208,13 @@ FlinkShuffleClientImpl getShuffleClient() { } /** Writes a piece of data to a subpartition. 
*/
-  public void write(ByteBuf byteBuf, int subIdx) {
+  public void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
     try {
       flinkShuffleClient.pushDataToLocation(
           shuffleId,
           mapId,
           attemptId,
-          subIdx,
+          bufferHeader.getSubPartitionId(),
           io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
           partitionLocation,
           () -> byteBuf.release());
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
index 6dc6350ce51..59e4d5010e0 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
@@ -37,6 +37,11 @@ public BufferHeader(Buffer.DataType dataType, boolean isCompressed, int size) {
     this(0, 0, 0, size + 2, dataType, isCompressed, size);
   }

+  public BufferHeader(
+      int subPartitionId, Buffer.DataType dataType, boolean isCompressed, int size) {
+    this(subPartitionId, 0, 0, size + 2, dataType, isCompressed, size);
+  }
+
   public BufferHeader(
       int subPartitionId,
       int attemptId,
@@ -54,6 +59,10 @@ public BufferHeader(
     this.size = size;
   }

+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
   public Buffer.DataType getDataType() {
     return dataType;
   }
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
index d0a757f19c3..76a6c2ef7aa 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
@@ -33,7 +33,15 @@ import org.apache.celeborn.plugin.flink.utils.Utils;
 import org.apache.celeborn.reflect.DynMethods;

-/** Harness used to pack multiple partial buffers together as a full one. */
+/**
+ * Harness used to pack multiple partial buffers together as a full one. There are two Flink
+ * integration strategies: Remote Shuffle Service and Hybrid Shuffle. In the Remote Shuffle Service
+ * integration strategy, the {@link BufferPacker} receives buffers containing both shuffle data
+ * and the Celeborn header. The Hybrid Shuffle integration strategy employs the subclass {@link
+ * ReceivedNoHeaderBufferPacker}, which receives buffers containing only shuffle data. The two
+ * integration strategies must therefore pack buffers with different methods, but the resulting
+ * packed buffer should be the same.
+ */ public class BufferPacker { private static Logger logger = LoggerFactory.getLogger(BufferPacker.class); @@ -41,14 +49,15 @@ public interface BiConsumerWithException { void accept(T var1, U var2) throws E; } - private final BiConsumerWithException ripeBufferHandler; + protected final BiConsumerWithException + ripeBufferHandler; - private Buffer cachedBuffer; + protected Buffer cachedBuffer; - private int currentSubIdx = -1; + protected int currentSubIdx = -1; public BufferPacker( - BiConsumerWithException ripeBufferHandler) { + BiConsumerWithException ripeBufferHandler) { this.ripeBufferHandler = ripeBufferHandler; } @@ -71,7 +80,8 @@ public void process(Buffer buffer, int subIdx) throws InterruptedException { int targetSubIdx = currentSubIdx; currentSubIdx = subIdx; logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes()); - handleRipeBuffer(dumpedBuffer, targetSubIdx); + handleRipeBuffer( + dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed()); } else { /** * this is an optimization. if cachedBuffer can contain other buffer, then other buffer can @@ -95,12 +105,13 @@ public void process(Buffer buffer, int subIdx) throws InterruptedException { cachedBuffer = buffer; logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes()); - handleRipeBuffer(dumpedBuffer, currentSubIdx); + handleRipeBuffer( + dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed()); } } } - private void logBufferPack(boolean isDrain, Buffer.DataType dataType, int length) { + protected void logBufferPack(boolean isDrain, Buffer.DataType dataType, int length) { logger.debug( "isDrain:{}, cachedBuffer pack partition:{} type:{}, length:{}", isDrain, @@ -112,15 +123,27 @@ private void logBufferPack(boolean isDrain, Buffer.DataType dataType, int length public void drain() throws InterruptedException { if (cachedBuffer != null) { logBufferPack(true, cachedBuffer.getDataType(), cachedBuffer.readableBytes()); - handleRipeBuffer(cachedBuffer, currentSubIdx); + handleRipeBuffer( + cachedBuffer, currentSubIdx, cachedBuffer.getDataType(), cachedBuffer.isCompressed()); } cachedBuffer = null; currentSubIdx = -1; } - private void handleRipeBuffer(Buffer buffer, int subIdx) throws InterruptedException { + protected void handleRipeBuffer( + Buffer buffer, int subIdx, Buffer.DataType dataType, boolean isCompressed) + throws InterruptedException { + // Always set the compress flag to false, because the result buffer generated by {@link + // BufferPacker} needs to be split into multiple buffers in unpack process, + // If the compress flag is set to true for this result buffer, it will throw an exception during + // the unpack process, as compressed buffer cannot be sliced. 
    buffer.setCompressed(false);
-    ripeBufferHandler.accept(buffer.asByteBuf(), subIdx);
+    ripeBufferHandler.accept(
+        buffer.asByteBuf(), new BufferHeader(subIdx, dataType, isCompressed, buffer.getSize()));
+  }
+
+  public boolean isEmpty() {
+    return cachedBuffer == null;
   }

   public void close() {
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
new file mode 100644
index 00000000000..09337ec4f72
--- /dev/null
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+
+/**
+ * Harness used to pack multiple partial buffers together as a full one. It is used in the Flink
+ * Hybrid Shuffle integration strategy.
+ */
+public class ReceivedNoHeaderBufferPacker extends BufferPacker {
+
+  /** The Flink buffer header of the cached first buffer.
*/ + private BufferHeader firstBufferHeader; + + public ReceivedNoHeaderBufferPacker( + BiConsumerWithException ripeBufferHandler) { + super(ripeBufferHandler); + } + + @Override + public void process(Buffer buffer, int subIdx) throws InterruptedException { + if (buffer == null) { + return; + } + + if (buffer.readableBytes() == 0) { + buffer.recycleBuffer(); + return; + } + + if (cachedBuffer == null) { + // cache the first buffer and record flink buffer header of first buffer + cachedBuffer = buffer; + currentSubIdx = subIdx; + firstBufferHeader = + new BufferHeader(subIdx, buffer.getDataType(), buffer.isCompressed(), buffer.getSize()); + } else if (currentSubIdx != subIdx) { + // drain the previous cached buffer and cache current buffer + Buffer dumpedBuffer = cachedBuffer; + cachedBuffer = buffer; + int targetSubIdx = currentSubIdx; + currentSubIdx = subIdx; + logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes()); + handleRipeBuffer( + dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed()); + firstBufferHeader = + new BufferHeader(subIdx, buffer.getDataType(), buffer.isCompressed(), buffer.getSize()); + } else { + int bufferHeaderLength = BufferUtils.HEADER_LENGTH - BufferUtils.HEADER_LENGTH_PREFIX; + if (cachedBuffer.readableBytes() + buffer.readableBytes() + bufferHeaderLength + <= cachedBuffer.getMaxCapacity() - BufferUtils.HEADER_LENGTH) { + // if the cache buffer can contain the current buffer, then pack the current buffer into + // cache buffer + ByteBuf byteBuf = cachedBuffer.asByteBuf(); + byteBuf.writeByte(buffer.getDataType().ordinal()); + byteBuf.writeBoolean(buffer.isCompressed()); + byteBuf.writeInt(buffer.getSize()); + byteBuf.writeBytes(buffer.asByteBuf(), 0, buffer.readableBytes()); + logBufferPack(false, buffer.getDataType(), buffer.readableBytes() + bufferHeaderLength); + + buffer.recycleBuffer(); + } else { + // if the cache buffer cannot contain the current buffer, drain the cached buffer, and cache + // the current buffer + Buffer dumpedBuffer = cachedBuffer; + cachedBuffer = buffer; + logBufferPack(false, dumpedBuffer.getDataType(), dumpedBuffer.readableBytes()); + + handleRipeBuffer( + dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(), dumpedBuffer.isCompressed()); + firstBufferHeader = + new BufferHeader(subIdx, buffer.getDataType(), buffer.isCompressed(), buffer.getSize()); + } + } + } + + @Override + protected void handleRipeBuffer( + Buffer buffer, int subIdx, Buffer.DataType dataType, boolean isCompressed) + throws InterruptedException { + if (buffer == null || buffer.readableBytes() == 0) { + return; + } + // Always set the compress flag to false, because this buffer contains Celeborn header and + // multiple flink data buffers. + // It is crucial to keep this flag set to false because we need to slice this buffer to extract + // flink data buffers + // during the unpacking process, the flink {@link NetworkBuffer} cannot correctly slice + // compressed buffer. 
+ buffer.setCompressed(false); + ripeBufferHandler.accept(buffer.asByteBuf(), firstBufferHeader); + } +} diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java index 14599e47722..999d1eb106d 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java @@ -59,6 +59,26 @@ public static void setCompressedDataWithHeader(Buffer buffer, Buffer compressedB buffer.setSize(dataLength + HEADER_LENGTH); } + /** + * It is utilized in Hybrid Shuffle integration strategy, in this case the buffer containing data + * only. Copies the data of the compressed buffer to the origin buffer. + */ + public static void setCompressedDataWithoutHeader(Buffer buffer, Buffer compressedBuffer) { + checkArgument(buffer != null, "Must be not null."); + checkArgument(buffer.getReaderIndex() == 0, "Illegal reader index."); + + boolean isCompressed = compressedBuffer != null && compressedBuffer.isCompressed(); + int dataLength = isCompressed ? compressedBuffer.readableBytes() : buffer.readableBytes(); + ByteBuf byteBuf = buffer.asByteBuf(); + if (isCompressed) { + byteBuf.writerIndex(0); + byteBuf.writeBytes(compressedBuffer.asByteBuf()); + // set the compression flag here, as we need it when writing the sub-header of this buffer + buffer.setCompressed(true); + } + buffer.setSize(dataLength); + } + public static void setBufferHeader( ByteBuf byteBuf, Buffer.DataType dataType, boolean isCompressed, int dataLength) { byteBuf.writerIndex(0); diff --git a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java index 2d5d5e78fd7..8f3c0ce6eee 100644 --- a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java +++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java @@ -23,20 +23,32 @@ import static org.junit.Assert.assertTrue; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import org.apache.commons.lang3.tuple.Pair; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.apache.celeborn.plugin.flink.buffer.BufferHeader; import org.apache.celeborn.plugin.flink.buffer.BufferPacker; +import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker; import org.apache.celeborn.plugin.flink.utils.BufferUtils; +@RunWith(Parameterized.class) public class BufferPackSuiteJ { private static final int BUFFER_SIZE = 20 + 16; @@ -44,6 +56,18 @@ public class BufferPackSuiteJ { private BufferPool bufferPool; + private boolean 
bufferPackerReceivedBufferHasHeader; + + public BufferPackSuiteJ(boolean bufferPackerReceivedBufferHasHeader) { + this.bufferPackerReceivedBufferHasHeader = bufferPackerReceivedBufferHasHeader; + } + + @Parameterized.Parameters + public static Collection prepareData() { + Object[][] object = {{true}, {false}}; + return Arrays.asList(object); + } + @Before public void setup() throws Exception { networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE); @@ -66,13 +90,14 @@ public void testPackEmptyBuffers() throws Exception { Integer subIdx = 2; List output = new ArrayList<>(); - BufferPacker.BiConsumerWithException ripeBufferHandler = - (ripe, sub) -> { - assertEquals(subIdx, sub); - output.add(ripe); - }; - - BufferPacker packer = new BufferPacker(ripeBufferHandler); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> { + assertEquals(subIdx, Integer.valueOf(header.getSubPartitionId())); + output.add(ripe); + }; + + BufferPacker packer = createBufferPakcer(ripeBufferHandler); packer.process(buffers.get(0), subIdx); packer.process(buffers.get(1), subIdx); packer.process(buffers.get(2), subIdx); @@ -89,9 +114,12 @@ public void testPartialBuffersForSameSubIdx() throws Exception { setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); List> output = new ArrayList<>(); - BufferPacker.BiConsumerWithException ripeBufferHandler = - (ripe, sub) -> output.add(Pair.of(ripe, sub)); - BufferPacker packer = new BufferPacker(ripeBufferHandler); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> + output.add( + Pair.of(addBufferHeaderPossible(ripe, header), header.getSubPartitionId())); + BufferPacker packer = createBufferPakcer(ripeBufferHandler); fillBuffers(buffers, 0, 1, 2); packer.process(buffers.get(0), 2); @@ -123,9 +151,12 @@ public void testPartialBuffersForMultipleSubIdx() throws Exception { setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); List> output = new ArrayList<>(); - BufferPacker.BiConsumerWithException ripeBufferHandler = - (ripe, sub) -> output.add(Pair.of(ripe, sub)); - BufferPacker packer = new BufferPacker(ripeBufferHandler); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> + output.add( + Pair.of(addBufferHeaderPossible(ripe, header), header.getSubPartitionId())); + BufferPacker packer = createBufferPakcer(ripeBufferHandler); fillBuffers(buffers, 0, 1, 2); packer.process(buffers.get(0), 0); @@ -158,9 +189,12 @@ public void testUnpackedBuffers() throws Exception { setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER); List> output = new ArrayList<>(); - BufferPacker.BiConsumerWithException ripeBufferHandler = - (ripe, sub) -> output.add(Pair.of(ripe, sub)); - BufferPacker packer = new BufferPacker(ripeBufferHandler); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> + output.add( + Pair.of(addBufferHeaderPossible(ripe, header), header.getSubPartitionId())); + BufferPacker packer = createBufferPakcer(ripeBufferHandler); fillBuffers(buffers, 0, 1, 2); packer.process(buffers.get(0), 0); @@ -186,6 +220,59 @@ public void testUnpackedBuffers() throws Exception { unpacked.forEach(Buffer::recycleBuffer); } + @Test + public void testPackMultipleBuffers() throws Exception { + int numBuffers = 7; + List buffers = new ArrayList<>(); + buffers.add(buildSomeBuffer(100)); + buffers.addAll(requestBuffers(numBuffers - 1)); + setCompressed(buffers, true, true, true, false, false, false, true); + setDataType( + buffers, + EVENT_BUFFER, + DATA_BUFFER, + 
DATA_BUFFER, + EVENT_BUFFER, + DATA_BUFFER, + DATA_BUFFER, + EVENT_BUFFER); + + List> output = new ArrayList<>(); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> + output.add( + Pair.of(addBufferHeaderPossible(ripe, header), header.getSubPartitionId())); + BufferPacker packer = createBufferPakcer(ripeBufferHandler); + fillBuffers(buffers, 0, 1, 2, 3, 4, 5, 6, 7); + + for (int i = 0; i < buffers.size(); i++) { + packer.process(buffers.get(i), 0); + } + packer.drain(); + + List unpacked = new ArrayList<>(); + for (int i = 0; i < output.size(); i++) { + Pair pair = output.get(i); + assertEquals(Integer.valueOf(0), pair.getRight()); + unpacked.addAll(BufferPacker.unpack(pair.getLeft())); + } + assertEquals(7, unpacked.size()); + + checkIfCompressed(unpacked, true, true, true, false, false, false, true); + checkDataType( + unpacked, + EVENT_BUFFER, + DATA_BUFFER, + DATA_BUFFER, + EVENT_BUFFER, + DATA_BUFFER, + DATA_BUFFER, + EVENT_BUFFER); + verifyBuffers(unpacked, 0, 1, 2, 3, 4, 5, 6, 7); + unpacked.forEach(Buffer::recycleBuffer); + } + @Test public void testFailedToHandleRipeBufferAndClose() throws Exception { List buffers = requestBuffers(1); @@ -193,12 +280,13 @@ public void testFailedToHandleRipeBufferAndClose() throws Exception { setDataType(buffers, DATA_BUFFER); fillBuffers(buffers, 0); - BufferPacker.BiConsumerWithException ripeBufferHandler = - (ripe, sub) -> { - // ripe.release(); - throw new RuntimeException("Test"); - }; - BufferPacker packer = new BufferPacker(ripeBufferHandler); + BufferPacker.BiConsumerWithException + ripeBufferHandler = + (ripe, header) -> { + // ripe.release(); + throw new RuntimeException("Test"); + }; + BufferPacker packer = createBufferPakcer(ripeBufferHandler); System.out.println(buffers.get(0).refCnt()); packer.process(buffers.get(0), 0); try { @@ -248,8 +336,17 @@ private void fillBuffers(List buffers, int... ints) { for (int i = 0; i < buffers.size(); i++) { Buffer buffer = buffers.get(i); ByteBuf target = buffer.asByteBuf(); - BufferUtils.setBufferHeader(target, buffer.getDataType(), buffer.isCompressed(), 4); - target.writerIndex(BufferUtils.HEADER_LENGTH); + + if (bufferPackerReceivedBufferHasHeader) { + // If the buffer includes a header, we need to leave space for the header, so we should + // update the writer index to BufferUtils.HEADER_LENGTH. + BufferUtils.setBufferHeader(target, buffer.getDataType(), buffer.isCompressed(), 4); + target.writerIndex(BufferUtils.HEADER_LENGTH); + } else { + // if the buffer does not have a header, we can directly write data starting from the + // beginning of the buffer. + target.writerIndex(0); + } target.writeInt(ints[i]); } } @@ -260,4 +357,51 @@ private void verifyBuffers(List buffers, int... 
expects) { assertEquals(expects[i], actual.getInt(0)); } } + + public static Buffer buildSomeBuffer(int size) { + final MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(size); + return new NetworkBuffer(seg, MemorySegment::free, Buffer.DataType.DATA_BUFFER, size); + } + + public ByteBuf addBufferHeaderPossible(ByteBuf byteBuf, BufferHeader bufferHeader) { + // Try to add buffer header if bufferPackerReceivedBufferHasHeader set to false in BufferPacker + // drain process + if (bufferPackerReceivedBufferHasHeader) { + return byteBuf; + } + + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + // create a small buffer headerBuf to write the buffer header + ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH); + + // write celeborn buffer header (subpartitionid(4) + attemptId(4) + nextBatchId(4) + + // compressedsize) + headerBuf.writeInt(bufferHeader.getSubPartitionId()); + headerBuf.writeInt(0); + headerBuf.writeInt(0); + headerBuf.writeInt( + byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH - BufferUtils.HEADER_LENGTH_PREFIX)); + + // write flink buffer header (dataType(1) + isCompress(1) + size(4)) + headerBuf.writeByte(bufferHeader.getDataType().ordinal()); + headerBuf.writeBoolean(bufferHeader.isCompressed()); + headerBuf.writeInt(bufferHeader.getSize()); + + // composite the headerBuf and data buffer together + compositeByteBuf.addComponents(true, headerBuf, byteBuf); + ByteBuf packedByteBuf = Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer()); + byteBuf.writerIndex(0); + byteBuf.writeBytes(packedByteBuf, 0, packedByteBuf.readableBytes()); + return byteBuf; + } + + public BufferPacker createBufferPakcer( + BufferPacker.BiConsumerWithException + ripeBufferHandler) { + if (bufferPackerReceivedBufferHasHeader) { + return new BufferPacker(ripeBufferHandler); + } else { + return new ReceivedNoHeaderBufferPacker(ripeBufferHandler); + } + } } diff --git a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java index 326a1198521..02306a5adc6 100644 --- a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java +++ b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java @@ -101,8 +101,16 @@ public TierProducerAgent createProducerAgent( ScheduledExecutorService ioExecutor, List shuffleDescriptors, int maxRequestedBuffers) { - // TODO impl this in the follow-up PR. - return null; + return new CelebornTierProducerAgent( + conf, + partitionId, + numPartitions, + numSubpartitions, + NUM_BYTES_PER_SEGMENT, + bufferSizeBytes, + storageMemoryManager, + resourceRegistry, + shuffleDescriptors); } @Override diff --git a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java new file mode 100644 index 00000000000..aab2b3ae54d --- /dev/null +++ b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.tiered; + +import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument; +import static org.apache.celeborn.plugin.flink.utils.Utils.checkState; +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.EndOfSegmentEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.util.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.exception.DriverChangedException; +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.plugin.flink.buffer.BufferHeader; +import org.apache.celeborn.plugin.flink.buffer.BufferPacker; +import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker; +import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl; +import org.apache.celeborn.plugin.flink.utils.BufferUtils; +import org.apache.celeborn.plugin.flink.utils.Utils; + +public class CelebornTierProducerAgent implements TierProducerAgent { + + private static final Logger LOG = LoggerFactory.getLogger(CelebornTierProducerAgent.class); + + private final int numBuffersPerSegment; + + private final int bufferSizeBytes; + + private final int numPartitions; + + private final int numSubPartitions; + + private final CelebornConf celebornConf; + + private final TieredStorageMemoryManager memoryManager; + + private final String applicationId; + + private final int shuffleId; + + private final int mapId; + + private final int 
attemptId; + + private final int partitionId; + + private final String lifecycleManagerHost; + + private final int lifecycleManagerPort; + + private final long lifecycleManagerTimestamp; + + private FlinkShuffleClientImpl flinkShuffleClient; + + private BufferPacker bufferPacker; + + private final int[] subPartitionSegmentIds; + + private final int[] subPartitionSegmentBuffers; + + private final int maxReviveTimes; + + private PartitionLocation partitionLocation; + + private boolean hasRegisteredShuffle; + + private int currentRegionIndex = 0; + + private int currentSubpartition = 0; + + private boolean hasSentHandshake = false; + + private boolean hasSentRegionStart = false; + + private volatile boolean isReleased; + + CelebornTierProducerAgent( + CelebornConf conf, + TieredStoragePartitionId partitionId, + int numPartitions, + int numSubPartitions, + int numBytesPerSegment, + int bufferSizeBytes, + TieredStorageMemoryManager memoryManager, + TieredStorageResourceRegistry resourceRegistry, + List shuffleDescriptors) { + checkArgument( + numBytesPerSegment >= bufferSizeBytes, "One segment should contain at least one buffer."); + checkArgument(shuffleDescriptors.size() == 1, "There should be only one shuffle descriptor."); + TierShuffleDescriptor descriptor = shuffleDescriptors.get(0); + checkArgument( + descriptor instanceof TierShuffleDescriptorImpl, + "Wrong shuffle descriptor type " + descriptor.getClass()); + TierShuffleDescriptorImpl shuffleDesc = (TierShuffleDescriptorImpl) descriptor; + + this.numBuffersPerSegment = numBytesPerSegment / bufferSizeBytes; + this.bufferSizeBytes = bufferSizeBytes; + this.memoryManager = memoryManager; + this.numPartitions = numPartitions; + this.numSubPartitions = numSubPartitions; + this.celebornConf = conf; + this.subPartitionSegmentIds = new int[numSubPartitions]; + this.subPartitionSegmentBuffers = new int[numSubPartitions]; + this.maxReviveTimes = conf.clientPushMaxReviveTimes(); + + this.applicationId = shuffleDesc.getCelebornAppId(); + this.shuffleId = + shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getShuffleId(); + this.mapId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getMapId(); + this.attemptId = + shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getAttemptId(); + this.partitionId = + shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getPartitionId(); + this.lifecycleManagerHost = shuffleDesc.getShuffleResource().getLifecycleManagerHost(); + this.lifecycleManagerPort = shuffleDesc.getShuffleResource().getLifecycleManagerPort(); + this.lifecycleManagerTimestamp = + shuffleDesc.getShuffleResource().getLifecycleManagerTimestamp(); + this.flinkShuffleClient = getShuffleClient(); + + Arrays.fill(subPartitionSegmentIds, -1); + Arrays.fill(subPartitionSegmentBuffers, 0); + + this.bufferPacker = new ReceivedNoHeaderBufferPacker(this::write); + resourceRegistry.registerResource(partitionId, this::releaseResources); + registerShuffle(); + try { + handshake(); + } catch (IOException e) { + Utils.rethrowAsRuntimeException(e); + } + } + + @Override + public boolean tryStartNewSegment( + TieredStorageSubpartitionId tieredStorageSubpartitionId, int segmentId, int minNumBuffers) { + int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId(); + checkState( + segmentId >= subPartitionSegmentIds[subPartitionId], "Wrong segment id " + segmentId); + subPartitionSegmentIds[subPartitionId] = segmentId; + // If the start segment rpc is sent, the worker side will expect 
that
+    // at least one buffer will be written in the next moment.
+    try {
+      flinkShuffleClient.segmentStart(
+          shuffleId, mapId, attemptId, subPartitionId, segmentId, partitionLocation);
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+    return true;
+  }
+
+  @Override
+  public boolean tryWrite(
+      TieredStorageSubpartitionId tieredStorageSubpartitionId,
+      Buffer buffer,
+      Object bufferOwner,
+      int numRemainingConsecutiveBuffers) {
+    // It should be noted that, unlike RemoteShuffleOutputGate#write, the received buffer contains
+    // only shuffle data and does not have any remaining space for writing the Celeborn header.
+
+    int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId();
+
+    if (subPartitionSegmentBuffers[subPartitionId] + 1 + numRemainingConsecutiveBuffers
+        >= numBuffersPerSegment) {
+      // End the current segment if the segment buffer count reaches the threshold
+      subPartitionSegmentBuffers[subPartitionId] = 0;
+      try {
+        bufferPacker.drain();
+      } catch (InterruptedException e) {
+        buffer.recycleBuffer();
+        ExceptionUtils.rethrow(e, "Failed to process buffer.");
+      }
+      appendEndOfSegmentBuffer(subPartitionId);
+      return false;
+    }
+
+    if (buffer.isBuffer()) {
+      memoryManager.transferBufferOwnership(
+          bufferOwner, CelebornTierFactory.getCelebornTierName(), buffer);
+    }
+
+    // write buffer to BufferPacker and record buffer count per subPartition per segment
+    processBuffer(buffer, subPartitionId);
+    subPartitionSegmentBuffers[subPartitionId]++;
+    return true;
+  }
+
+  @Override
+  public void close() {
+    if (hasSentRegionStart) {
+      regionFinish();
+    }
+    try {
+      if (hasRegisteredShuffle && partitionLocation != null) {
+        flinkShuffleClient.mapPartitionMapperEnd(
+            shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId());
+      }
+    } catch (Exception e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+    bufferPacker.close();
+    bufferPacker = null;
+    flinkShuffleClient.cleanup(shuffleId, mapId, attemptId);
+    flinkShuffleClient = null;
+  }
+
+  private void regionStartOrFinish(int subPartitionId) {
+    // check whether the region should be started or finished
+    regionStart();
+    if (subPartitionId < currentSubpartition) {
+      // if the consumed subPartitionId is out of order, it means that the previous region
+      // should be finished and a new region started.
+ regionFinish(); + LOG.debug( + "Check region finish sub partition id {} and start next region {}", + subPartitionId, + currentRegionIndex); + regionStart(); + } + } + + private void regionStart() { + if (hasSentRegionStart) { + return; + } + regionStartWithRevive(); + } + + private void regionStartWithRevive() { + try { + int remainingReviveTimes = maxReviveTimes; + while (remainingReviveTimes-- > 0 && !hasSentRegionStart) { + Optional revivePartition = + flinkShuffleClient.regionStart( + shuffleId, mapId, attemptId, partitionLocation, currentRegionIndex, false); + if (revivePartition.isPresent()) { + LOG.info( + "Revive at regionStart, currentTimes:{}, totalTimes:{} for shuffleId:{}, mapId:{}, " + + "attempId:{}, currentRegionIndex:{}, isBroadcast:{}, newPartition:{}, oldPartition:{}", + remainingReviveTimes, + maxReviveTimes, + shuffleId, + mapId, + attemptId, + currentRegionIndex, + false, + revivePartition, + partitionLocation); + partitionLocation = revivePartition.get(); + // For every revive partition, handshake should be sent firstly + hasSentHandshake = false; + handshake(); + if (numSubPartitions > 0) { + for (int i = 0; i < numSubPartitions; i++) { + flinkShuffleClient.segmentStart( + shuffleId, mapId, attemptId, i, subPartitionSegmentIds[i], partitionLocation); + } + } + } else { + hasSentRegionStart = true; + currentSubpartition = 0; + } + } + if (remainingReviveTimes == 0 && !hasSentRegionStart) { + throw new RuntimeException( + "After retry " + maxReviveTimes + " times, still failed to send regionStart"); + } + } catch (IOException e) { + Utils.rethrowAsRuntimeException(e); + } + } + + void regionFinish() { + try { + bufferPacker.drain(); + flinkShuffleClient.regionFinish(shuffleId, mapId, attemptId, partitionLocation); + hasSentRegionStart = false; + currentRegionIndex++; + } catch (Exception e) { + Utils.rethrowAsRuntimeException(e); + } + } + + private void handshake() throws IOException { + try { + int remainingReviveTimes = maxReviveTimes; + while (remainingReviveTimes-- > 0 && !hasSentHandshake) { + Optional revivePartition = + flinkShuffleClient.pushDataHandShake( + shuffleId, mapId, attemptId, numSubPartitions, bufferSizeBytes, partitionLocation); + // if remainingReviveTimes == 0 and revivePartition.isPresent(), there is no need to send + // handshake again + if (revivePartition.isPresent() && remainingReviveTimes > 0) { + LOG.info( + "Revive at handshake, currentTimes:{}, totalTimes:{} for shuffleId:{}, mapId:{}, " + + "attempId:{}, currentRegionIndex:{}, newPartition:{}, oldPartition:{}", + remainingReviveTimes, + maxReviveTimes, + shuffleId, + mapId, + attemptId, + currentRegionIndex, + revivePartition, + partitionLocation); + partitionLocation = revivePartition.get(); + hasSentHandshake = false; + } else { + hasSentHandshake = true; + } + } + if (remainingReviveTimes == 0 && !hasSentHandshake) { + throw new RuntimeException( + "After retry " + maxReviveTimes + " times, still failed to send handshake"); + } + } catch (IOException e) { + Utils.rethrowAsRuntimeException(e); + } + } + + private void releaseResources() { + if (!isReleased) { + isReleased = true; + } + } + + private void registerShuffle() { + try { + if (!hasRegisteredShuffle) { + partitionLocation = + flinkShuffleClient.registerMapPartitionTask( + shuffleId, numPartitions, mapId, attemptId, partitionId, true); + Utils.checkNotNull(partitionLocation); + hasRegisteredShuffle = true; + } + } catch (IOException e) { + Utils.rethrowAsRuntimeException(e); + } + } + + private void write(ByteBuf byteBuf, 
BufferHeader bufferHeader) {
+    try {
+      // create a composite buffer and write a header into it. This composite buffer will serve as
+      // the result packed buffer.
+      CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+      ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH);
+
+      // write celeborn buffer header (subpartitionid(4) + attemptId(4) + nextBatchId(4) +
+      // compressedsize)
+      headerBuf.writeInt(bufferHeader.getSubPartitionId());
+      headerBuf.writeInt(attemptId);
+      headerBuf.writeInt(0);
+      headerBuf.writeInt(
+          byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH - BufferUtils.HEADER_LENGTH_PREFIX));
+
+      // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+      headerBuf.writeByte(bufferHeader.getDataType().ordinal());
+      headerBuf.writeBoolean(bufferHeader.isCompressed());
+      headerBuf.writeInt(bufferHeader.getSize());
+
+      // composite the headerBuf and data buffer together
+      compositeByteBuf.addComponents(true, headerBuf, byteBuf);
+      io.netty.buffer.ByteBuf wrappedBuffer =
+          io.netty.buffer.Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer());
+
+      int numWritten =
+          flinkShuffleClient.pushDataToLocation(
+              shuffleId,
+              mapId,
+              attemptId,
+              bufferHeader.getSubPartitionId(),
+              wrappedBuffer,
+              partitionLocation,
+              compositeByteBuf::release);
+      checkState(
+          numWritten == byteBuf.readableBytes() + BufferUtils.HEADER_LENGTH, "Wrong written size.");
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  private void appendEndOfSegmentBuffer(int subPartitionId) {
+    try {
+      checkState(bufferPacker.isEmpty(), "BufferPacker is not empty");
+      MemorySegment endSegmentMemorySegment =
+          MemorySegmentFactory.wrap(
+              EventSerializer.toSerializedEvent(EndOfSegmentEvent.INSTANCE).array());
+      Buffer endOfSegmentBuffer =
+          new NetworkBuffer(
+              endSegmentMemorySegment,
+              FreeingBufferRecycler.INSTANCE,
+              END_OF_SEGMENT,
+              endSegmentMemorySegment.size());
+      processBuffer(endOfSegmentBuffer, subPartitionId);
+    } catch (Exception e) {
+      ExceptionUtils.rethrow(e, "Failed to append end of segment event.");
+    }
+  }
+
+  private void processBuffer(Buffer originBuffer, int subPartitionId) {
+    try {
+      regionStartOrFinish(subPartitionId);
+      currentSubpartition = subPartitionId;
+
+      Buffer buffer = originBuffer;
+      if (originBuffer.isCompressed()) {
+        // In Flink 1.20.0, it will receive a compressed buffer. However, since we need to write
+        // data to this buffer and the compressed buffer is read-only,
+        // we must create a new Buffer object to wrap the origin buffer.
+        NetworkBuffer networkBuffer =
+            new NetworkBuffer(
+                originBuffer.getMemorySegment(),
+                originBuffer.getRecycler(),
+                originBuffer.getDataType(),
+                originBuffer.getSize());
+        networkBuffer.writerIndex(originBuffer.asByteBuf().writerIndex());
+        buffer = networkBuffer;
+      }
+
+      // TODO: To enhance performance, Flink should pass an uncompressed buffer to the producer
+      // agent and we should compress the buffer here
+
+      // set the buffer meta
+      BufferUtils.setCompressedDataWithoutHeader(buffer, originBuffer);
+
+      bufferPacker.process(buffer, subPartitionId);
+    } catch (InterruptedException e) {
+      originBuffer.recycleBuffer();
+      ExceptionUtils.rethrow(e, "Failed to process buffer.");
+    }
+  }
+
+  @VisibleForTesting
+  FlinkShuffleClientImpl getShuffleClient() {
+    try {
+      return FlinkShuffleClientImpl.get(
+          applicationId,
+          lifecycleManagerHost,
+          lifecycleManagerPort,
+          lifecycleManagerTimestamp,
+          celebornConf,
+          null);
+    } catch (DriverChangedException e) {
+      // would generate a new attempt to retry output gate
+      throw new RuntimeException(e.getMessage());
+    }
+  }
+}
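
Note on the packed buffer layout used throughout this patch: BufferPacker#handleRipeBuffer, BufferPackSuiteJ#addBufferHeaderPossible and CelebornTierProducerAgent#write all produce the same wire format, a Celeborn header (subPartitionId, attemptId, nextBatchId, length of the remainder) followed by one or more Flink sub-headers (dataType, isCompressed, size), each immediately preceding its payload. The standalone sketch below only illustrates that layout with plain Netty buffers; the class, method name and the 16/6 byte constants are illustrative reconstructions from the header comments in this diff, not code from the patch.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class PackedBufferLayoutSketch {
  // Illustrative sizes, mirroring the header comments in this patch:
  // Celeborn header = subPartitionId(4) + attemptId(4) + nextBatchId(4) + length(4).
  private static final int CELEBORN_HEADER = 16;
  // Flink sub-header = dataType(1) + isCompressed(1) + size(4).
  private static final int FLINK_SUB_HEADER = 6;

  /** Packs a single payload the way write(...) composes headerBuf + data buffer. */
  static ByteBuf packWithHeader(
      ByteBuf payload, int subPartitionId, int attemptId, byte dataType, boolean isCompressed) {
    ByteBuf packed = Unpooled.buffer(CELEBORN_HEADER + FLINK_SUB_HEADER + payload.readableBytes());
    // Celeborn header.
    packed.writeInt(subPartitionId);
    packed.writeInt(attemptId);
    packed.writeInt(0); // nextBatchId is always written as 0 in this code path
    packed.writeInt(FLINK_SUB_HEADER + payload.readableBytes());
    // Flink sub-header followed by the payload itself.
    packed.writeByte(dataType);
    packed.writeBoolean(isCompressed);
    packed.writeInt(payload.readableBytes());
    packed.writeBytes(payload, payload.readerIndex(), payload.readableBytes());
    return packed;
  }
}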
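
On the read side, the tests rely on BufferPacker.unpack(ByteBuf), which is not part of this diff. As a rough illustration of why the packed buffer must not be flagged as compressed (it has to be sliced into per-record buffers), iterating the Flink sub-headers looks roughly like the sketch below. This assumes the Celeborn header has already been stripped; the class and method names are illustrative, not the actual unpack implementation.

import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;

public class UnpackSketch {
  /** Splits a packed buffer (Celeborn header already stripped) into per-record slices. */
  static List<ByteBuf> splitSubBuffers(ByteBuf packed) {
    List<ByteBuf> slices = new ArrayList<>();
    while (packed.isReadable()) {
      byte dataType = packed.readByte(); // Flink DataType ordinal, used to rebuild the Buffer
      boolean isCompressed = packed.readBoolean(); // per-record compression flag
      int size = packed.readInt();
      // retainedSlice keeps the underlying memory alive for each logical record;
      // slicing like this is only legal because the outer buffer is not marked compressed.
      slices.add(packed.retainedSlice(packed.readerIndex(), size));
      packed.skipBytes(size);
    }
    return slices;
  }
}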
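
The region/segment bookkeeping in CelebornTierProducerAgent can be summarized as: a segment is started per sub-partition via segmentStart, buffers are counted per segment and the segment is ended once its buffer budget is exhausted, and a region is finished and restarted whenever a buffer arrives for a smaller sub-partition id than the one currently being written. A condensed, dependency-free sketch of that decision logic follows; the class and method names are mine, not the patch's.

public class RegionSegmentTrackerSketch {
  private final int numBuffersPerSegment;
  private int currentSubpartition = 0;
  private boolean regionStarted = false;

  RegionSegmentTrackerSketch(int numBuffersPerSegment) {
    this.numBuffersPerSegment = numBuffersPerSegment;
  }

  /** Mirrors regionStartOrFinish: an out-of-order sub-partition id closes the current region. */
  void onBuffer(int subPartitionId) {
    if (!regionStarted) {
      regionStarted = true; // regionStart RPC (with handshake/revive retries) happens here
    }
    if (subPartitionId < currentSubpartition) {
      regionStarted = false; // regionFinish RPC, then the next buffer opens a new region
      onBuffer(subPartitionId);
      return;
    }
    currentSubpartition = subPartitionId;
  }

  /** Mirrors the check in tryWrite: a segment ends once its buffer budget is exhausted. */
  boolean segmentFull(int buffersInSegment, int numRemainingConsecutiveBuffers) {
    return buffersInSegment + 1 + numRemainingConsecutiveBuffers >= numBuffersPerSegment;
  }
}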