diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index fdee16b72..a5baedda5 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -356,25 +356,38 @@ def forward_unix_socket(self, unixsocket, localport=None): yield localport @Driver.check_active - @step(args=['src', 'dst']) - def scp(self, *, src, dst): + @step(args=['src', 'dst', 'recursive']) + def scp(self, *, src, dst, recursive=False): if not self._check_keepalive(): raise ExecutionError("Keepalive no longer running") - if src.startswith(':') == dst.startswith(':'): + if not isinstance(src, list): + src = [src] + + # take Path like objects into account + src = [str(f) for f in src] + dst = str(dst) + + remote_src = [f.startswith(':') for f in src] + if any(remote_src) != all(remote_src): + raise ValueError("All sources must be consistently local or remote (start with :)") + + if all(remote_src) == dst.startswith(':'): raise ValueError("Either source or destination must be remote (start with :)") - if src.startswith(':'): - src = '_' + src - if dst.startswith(':'): - dst = '_' + dst + + src = [s.replace(':', '_:') for s in src] + dst = dst.replace(':', '_:') complete_cmd = [self._scp, "-S", self._ssh, "-F", "none", "-o", f"ControlPath={self.control.replace('%', '%%')}", - src, dst, + *src, + dst, ] - + + if recursive: + complete_cmd.insert(1, "-r") if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode(): complete_cmd.insert(1, "-s") if self.explicit_scp_mode and self._scp_supports_explicit_scp_mode(): @@ -594,3 +607,4 @@ def _stop_keepalive(self): if stdout: for line in stdout.splitlines(): self.logger.warning("Keepalive %s: %s", self.networkservice.address, line) + diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index 58b2720ec..233a5f729 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -1306,7 +1306,7 @@ def ssh(self): def scp(self): drv = self._get_ssh() - res = drv.scp(src=self.args.src, dst=self.args.dst) + res = drv.scp(src=self.args.src, dst=self.args.dst, recursive=self.args.recursive) if res: raise InteractiveCommandError("scp error", res) @@ -2009,8 +2009,9 @@ def get_parser(auto_doc_mode=False) -> "argparse.ArgumentParser | AutoProgramArg subparser = subparsers.add_parser("scp", help="transfer file via scp") subparser.add_argument("--name", "-n", help="optional resource name") - subparser.add_argument("src", help="source path (use :dir/file for remote side)") + subparser.add_argument("src", nargs="+", help="source path (use :dir/file for remote side)") subparser.add_argument("dst", help="destination path (use :dir/file for remote side)") + subparser.add_argument("--recursive", "-r", action="store_true", help="copy recursive") subparser.set_defaults(func=ClientSession.scp) subparser = subparsers.add_parser( diff --git a/man/labgrid-client.1 b/man/labgrid-client.1 index fdc62be72..fe77f2c0b 100644 --- a/man/labgrid-client.1 +++ b/man/labgrid-client.1 @@ -678,7 +678,7 @@ transfer file via scp .INDENT 3.5 .sp .EX -usage: labgrid\-client scp [\-\-name NAME] src dst +usage: labgrid\-client scp [\-\-name NAME] [\-\-recursive] src [src ...] dst .EE .UNINDENT .UNINDENT @@ -697,6 +697,11 @@ destination path (use :dir/file for remote side) .B \-\-name , \-n optional resource name .UNINDENT +.INDENT 0.0 +.TP +.B \-\-recursive, \-r +copy recursive +.UNINDENT .SS labgrid\-client sd\-mux .sp switch USB SD Muxer or get current mode diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index 4c233a834..69cfbc38c 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -1,5 +1,6 @@ import pytest import socket +import os from labgrid import Environment from labgrid.driver import SSHDriver, ExecutionError @@ -218,3 +219,157 @@ def test_unix_socket_forward(ssh_localhost, tmpdir): send_socket.send(test_string.encode("utf-8")) assert client_socket.recv(16).decode("utf-8") == test_string + + +@pytest.mark.sshusername +def test_local_scp_to(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + + os.mkdir(l_dir) + os.mkdir(r_dir) + + magic = ["FOObar 1337 scp-to"] + name = "test_scp-to.txt" + + file = l_dir.join(name) + open(file, 'x').writelines(magic) + + ssh_localhost.scp(src=f'{file}', dst=f':{r_dir}') + assert open(r_dir.join(name), 'r').readlines() == magic + + +@pytest.mark.sshusername +def test_local_scp_from(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + + os.mkdir(l_dir) + os.mkdir(r_dir) + + magic = ["FOObar 1337 scp-to"] + name = 'test_scp-from.txt' + + file = r_dir.join(name) + open(file, 'x').writelines(magic) + + ssh_localhost.scp(src=f':{file}', dst=f'{l_dir}') + assert open(l_dir.join(name), 'r').readlines() == magic + + +@pytest.mark.sshusername +def test_local_scp_to_multi(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + n_files = 13 + + magics = [[f"FOObar 1337 scp-to_{i}"] for i in range(n_files)] + names = [f"test_scp-to_{i}.txt" for i in range(n_files)] + + files = [str(l_dir.join(name)) for name in names] + for i in range(n_files): + open(files[i], 'x').writelines(magics[i]) + + ssh_localhost.scp(src=files, dst=f':{r_dir}') + + for i in range(n_files): + assert open(r_dir.join(names[i]), 'r').readlines() == magics[i] + + +@pytest.mark.sshusername +def test_local_scp_from_multi(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + n_files = 13 + + magics = [[f"FOObar 1337 scp-from_{i}"] for i in range(n_files)] + names = [f"test_scp-from_{i}.txt" for i in range(n_files)] + + files = [str(r_dir.join(name)) for name in names] + for i in range(n_files): + open(files[i], 'x').writelines(magics[i]) + + ssh_localhost.scp(src=[f":{f}" for f in files], dst=f'{l_dir}') + + for i in range(n_files): + assert open(l_dir.join(names[i]), 'r').readlines() == magics[i] + + +@pytest.mark.sshusername +def test_local_scp_to_recursive(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + n_files = 13 + + magics = [[f"FOObar 1337 scp-to_{i}"] for i in range(n_files)] + names = [f"test_scp-to_{i}.txt" for i in range(n_files)] + + files = [str(l_dir.join(name)) for name in names] + for i in range(n_files): + open(files[i], 'x').writelines(magics[i]) + + ssh_localhost.scp(src=f"{l_dir}", dst=f':{r_dir}', recursive=True) + + for i in range(n_files): + assert open(r_dir.join("local").join(names[i]), 'r').readlines() == magics[i] + + +@pytest.mark.sshusername +def test_local_scp_from_recursive(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + n_files = 13 + + magics = [[f"FOObar 1337 scp-from_{i}"] for i in range(n_files)] + names = [f"test_scp-from_{i}.txt" for i in range(n_files)] + + files = [str(r_dir.join(name)) for name in names] + for i in range(n_files): + open(files[i], 'x').writelines(magics[i]) + + ssh_localhost.scp(src=f":{r_dir}", dst=f'{l_dir}', recursive=True) + + for i in range(n_files): + assert open(l_dir.join("remote").join(names[i]), 'r').readlines() == magics[i] + + +@pytest.mark.sshusername +def test_local_scp_none_remote(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + try: + ssh_localhost.scp(src=l_dir, dst=r_dir) + except ValueError: + return + + assert False + + +@pytest.mark.sshusername +def test_local_scp_both_remote(ssh_localhost, tmpdir): + l_dir = tmpdir.join("local") + r_dir = tmpdir.join("remote") + os.mkdir(l_dir) + os.mkdir(r_dir) + + try: + ssh_localhost.scp(src=f":{l_dir}", dst=f":{r_dir}") + except ValueError: + return + + assert False