From 8db71d71d2458ec394b4aacf3b360b8151f7446e Mon Sep 17 00:00:00 2001 From: Martijn Vegter Date: Sun, 14 Jan 2024 15:54:12 +0100 Subject: [PATCH] ChannelSftp cannot download directories, it results in a zero byte file, and should thus fail consistently The file retrievel with a user provided destination as String already validates if the source file is a directory and fails accordingly. This change is for the user provided OutputStream flow where this validation was missing. --- .../java/com/jcraft/jsch/ChannelSftp.java | 6 +- .../java/com/jcraft/jsch/SftpRetrievalIT.java | 133 ++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 src/test/java/com/jcraft/jsch/SftpRetrievalIT.java diff --git a/src/main/java/com/jcraft/jsch/ChannelSftp.java b/src/main/java/com/jcraft/jsch/ChannelSftp.java index 4848339e..6f19969a 100644 --- a/src/main/java/com/jcraft/jsch/ChannelSftp.java +++ b/src/main/java/com/jcraft/jsch/ChannelSftp.java @@ -975,8 +975,12 @@ public void get(String src, OutputStream dst, SftpProgressMonitor monitor, int m src = remoteAbsolutePath(src); src = isUnique(src); + SftpATTRS attr = _stat(src); + if (attr.isDir()) { + throw new SftpException(SSH_FX_FAILURE, "not supported to get directory " + src); + } + if (monitor != null) { - SftpATTRS attr = _stat(src); monitor.init(SftpProgressMonitor.GET, src, "??", attr.getSize()); if (mode == RESUME) { monitor.count(skip); diff --git a/src/test/java/com/jcraft/jsch/SftpRetrievalIT.java b/src/test/java/com/jcraft/jsch/SftpRetrievalIT.java new file mode 100644 index 00000000..ff1ea161 --- /dev/null +++ b/src/test/java/com/jcraft/jsch/SftpRetrievalIT.java @@ -0,0 +1,133 @@ +package com.jcraft.jsch; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.FileOutputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Base64; +import java.util.List; +import java.util.Locale; +import java.util.Random; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.images.builder.ImageFromDockerfile; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +@Testcontainers +class SftpRetrievalIT { + + private static final int timeout = 10000; + + @TempDir + public Path tmpDir; + private Path in; + private Path out; + + @Container + public GenericContainer sshd = new GenericContainer<>( + new ImageFromDockerfile().withFileFromClasspath("asyncsshd.py", "docker/asyncsshd.py") + .withFileFromClasspath("ssh_host_ed448_key", "docker/ssh_host_ed448_key") + .withFileFromClasspath("ssh_host_ed448_key.pub", "docker/ssh_host_ed448_key.pub") + .withFileFromClasspath("ssh_host_rsa_key", "docker/ssh_host_rsa_key") + .withFileFromClasspath("ssh_host_rsa_key.pub", "docker/ssh_host_rsa_key.pub") + .withFileFromClasspath("ssh_host_ecdsa256_key", "docker/ssh_host_ecdsa256_key") + .withFileFromClasspath("ssh_host_ecdsa256_key.pub", "docker/ssh_host_ecdsa256_key.pub") + .withFileFromClasspath("ssh_host_ecdsa384_key", "docker/ssh_host_ecdsa384_key") + .withFileFromClasspath("ssh_host_ecdsa384_key.pub", "docker/ssh_host_ecdsa384_key.pub") + .withFileFromClasspath("ssh_host_ecdsa521_key", "docker/ssh_host_ecdsa521_key") + .withFileFromClasspath("ssh_host_ecdsa521_key.pub", "docker/ssh_host_ecdsa521_key.pub") + .withFileFromClasspath("ssh_host_ed25519_key", "docker/ssh_host_ed25519_key") + .withFileFromClasspath("ssh_host_ed25519_key.pub", "docker/ssh_host_ed25519_key.pub") + .withFileFromClasspath("ssh_host_dsa_key", "docker/ssh_host_dsa_key") + .withFileFromClasspath("ssh_host_dsa_key.pub", "docker/ssh_host_dsa_key.pub") + .withFileFromClasspath("authorized_keys", "docker/authorized_keys") + .withFileFromClasspath("Dockerfile", "docker/Dockerfile.asyncssh")) + .withExposedPorts(22); + + @BeforeEach + public void beforeEach() throws Exception { + in = tmpDir.resolve("in"); + out = tmpDir.resolve("out"); + Files.createFile(in); + try (OutputStream os = Files.newOutputStream(in)) { + byte[] data = new byte[1024]; + for (int i = 0; i < 1024 * 100; i += 1024) { + new Random().nextBytes(data); + os.write(data); + } + } + } + + @Test + void testDirectoryRetrievalDestinationAsString() throws Exception { + JSch ssh = createRSAIdentity(); + Session session = createSession(ssh); + + SftpException sftpException = assertThrows(SftpException.class, + () -> doSftp(session, (sftp) -> sftp.get("/root/", out.toString()))); + assertEquals("not supported to get directory /root/", sftpException.getMessage()); + } + + @Test + void testDirectoryRetrievalDestinationAsStream() throws Exception { + JSch ssh = createRSAIdentity(); + Session session = createSession(ssh); + + SftpException sftpException = assertThrows(SftpException.class, + () -> doSftp(session, (sftp) -> sftp.get("/root/", new FileOutputStream(out.toString())))); + assertEquals("not supported to get directory /root/", sftpException.getMessage()); + } + + private JSch createRSAIdentity() throws Exception { + HostKey hostKey = readHostKey(getResourceFile("docker/ssh_host_rsa_key.pub")); + JSch ssh = new JSch(); + ssh.addIdentity(getResourceFile("docker/id_rsa"), getResourceFile("docker/id_rsa.pub"), null); + ssh.getHostKeyRepository().add(hostKey, null); + return ssh; + } + + private Session createSession(JSch ssh) throws Exception { + Session session = ssh.getSession("root", sshd.getHost(), sshd.getFirstMappedPort()); + session.setConfig("StrictHostKeyChecking", "yes"); + session.setConfig("PreferredAuthentications", "publickey"); + return session; + } + + private String getResourceFile(String fileName) { + return ResourceUtil.getResourceFile(getClass(), fileName); + } + + private HostKey readHostKey(String fileName) throws Exception { + List lines = Files.readAllLines(Paths.get(fileName), UTF_8); + String[] split = lines.get(0).split("\\s+"); + String hostname = + String.format(Locale.ROOT, "[%s]:%d", sshd.getHost(), sshd.getFirstMappedPort()); + return new HostKey(hostname, Base64.getDecoder().decode(split[1])); + } + + private void doSftp(Session session, ThrowingConsumer method) + throws Exception { + session.setTimeout(timeout); + session.connect(); + ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp"); + sftp.connect(timeout); + sftp.put(in.toString(), "/root/test"); + method.accept(sftp); + sftp.disconnect(); + session.disconnect(); + } + + @FunctionalInterface + public interface ThrowingConsumer { + void accept(T t) throws E; + } + +}