#!/usr/bin/python3
#
# Univention SSH
#  ssh wrapper script
#
# SPDX-FileCopyrightText: 2004-2026 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

import os
import shlex
import signal
import sys
import time
from textwrap import dedent

from univention.config_registry import ucr


def display_help(out=sys.stdout):
    """
    Print the usage text for the invoked command to *out*.

    The module-level ``cmd`` selects the text: ``univention-ssh`` and
    ``univention-scp`` have their own variant, every other name gets the
    rsync variant.

    :param out: file-like object the text is written to.
    """
    ssh_usage = '''
        %(cmd)s: internal managing tool for UCS clients
        Copyright (c) 2001-2026 Univention GmbH, Germany

        Syntax:
          %(cmd)s [-timeout <seconds>] [--no-split] <pwd_file> [ssh-options] [user@]<host> <command> [<arguments>...]
          %(cmd)s [--help] [--version]

        Parameters:
          pwd_file                         file containing the password
          options                          optional options to %(cmd)s
          ssh-options                      additional options to ssh
          user                             user to log in as on the remote system
          host                             name or ip address of the remote system
          command                          command to be executed on remote system
          arguments                        additional arguments to remote command

        Options:
          -timeout SECONDS:                Specify a timeout in seconds, after
                                           which the ssh connection is forcefully terminated.
          --no-split:                      don't word-split command
          -h | --help | -?:                print this usage message
          --version:                       print version information

        Description:
          %(cmd)s is an internal command to use SSH with plain text password
          authentication, which is used to execute commands on remote computer systems,
          e.g. %(cmd)s /tmp/pwd root@192.168.0.31 ls -la /
        '''
    scp_usage = '''
        %(cmd)s: internal managing tool for UCS clients
        Copyright (c) 2001-2026 Univention GmbH, Germany

        Syntax:
          %(cmd)s [-timeout <seconds>] [--no-split] <pwd_file> [scp-options] files [user@]host:[path]
          %(cmd)s [--help] [--version]

        Parameters:
          pwd_file                         file containing the password
          scp-options                      additional options to scp
          files                            source files or directories
          user                             user to log in as on the remote system
          host                             name or ip address of the remote system
          path                             destination on the remote system

        Options:
          -timeout SECONDS:                Specify a timeout in seconds, after
                                           which the ssh connection is forcefully terminated.
          --no-split:                      don't word-split command
          -h | --help | -?:                print this usage message
          --version:                       print version information

        Description:
          %(cmd)s is an internal command to use SCP with plain text password
          authentication, which is used to copy files from/to remote computer systems,
          e.g. %(cmd)s /tmp/pwd file.txt root@192.168.0.31:
        '''
    rsync_usage = '''
        %(cmd)s: internal managing tool for UCS clients
        Copyright (c) 2001-2026 Univention GmbH, Germany

        Syntax:
          %(cmd)s [-timeout <seconds>] [--no-split] <pwd_file> [rsync-options] files [user@]host:[path]
          %(cmd)s [--help] [--version]

        Parameters:
          pwd_file                         file containing the password
          rsync-options                    additional options to rsync
          files                            source files or directories
          user                             user to log in as on the remote system
          host                             name or ip address of the remote system
          path                             destination on the remote system

        Options:
          -timeout SECONDS:                Specify a timeout in seconds, after
                                           which the ssh connection is forcefully terminated.
          --no-split:                      don't word-split command
          -h | --help | -?:                print this usage message
          --version:                       print version information

        Description:
          %(cmd)s is an internal command to use RSYNC with plain text password
          authentication, which is used to copy files from/to remote computer systems,
          e.g. %(cmd)s /tmp/pwd file.txt root@192.168.0.31:
        '''
    # Names without a dedicated entry fall back to the rsync text.
    templates = {
        'univention-ssh': ssh_usage,
        'univention-scp': scp_usage,
    }
    template = templates.get(cmd, rsync_usage)
    # Substitute the command name first, then remove the source indentation.
    usage = dedent(template % {'cmd': cmd})
    print(usage.strip("\n"), file=out)


def display_version(out=sys.stdout):
    """
    Print the invoked command name followed by the version number to *out*.

    :param out: file-like object the line is written to.
    """
    banner = f'{cmd} 13.6.0'
    print(banner, file=out)


# Path that is exported as SSH_ASKPASS for the child process. When this script
# is invoked under exactly this path it only prints the password taken from
# the environment and exits.
UNIVENTION_SSH_ASKPASS_HELPER = '/usr/lib/univention-ssh/univention-ssh-askpass'
# Name of the environment variable that carries the password to the askpass helper.
UNIVENTION_SSH_ASKPASS_ENV = 'UNIVENTION_SSH_ASKPASS'

# Script entry point.
#
# Exit codes produced by the wrapper itself:
#   0  --help / --version, or askpass mode succeeded
#   1  askpass mode without password, fork/exec failure, child timed out or
#      was interrupted
#   2  usage error, unknown command name, unreadable password file, or any
#      other operating system error
# Otherwise the exit status of the wrapped command is passed through.
try:
    program = sys.argv.pop(0)
    if program == UNIVENTION_SSH_ASKPASS_HELPER:
        # Askpass mode: the script was started under the path exported as
        # SSH_ASKPASS; hand back the password stored in the environment.
        try:
            print(os.environ[UNIVENTION_SSH_ASKPASS_ENV], end="")
            sys.exit(0)
        except KeyError:
            sys.exit(1)
    # The name the script was invoked as selects the wrapped tool.
    cmd = os.path.basename(program)
    split = True
    timeout = 3600

    # Consume wrapper options until the password file argument shows up.
    # Running out of arguments raises IndexError, which is reported as a
    # usage error by the handler at the bottom.
    while True:
        arg = sys.argv.pop(0)
        if arg == '--version':
            display_version()
            sys.exit(0)
        elif arg in ('-?', '--help', '-h'):
            display_help()
            sys.exit(0)
        elif arg == '-timeout':
            arg = sys.argv.pop(0)
            try:
                timeout = int(arg)
            except ValueError:
                # A non-numeric timeout is a usage error, not a crash.
                raise IndexError() from None
        elif arg == '--no-split':
            split = False
        elif arg.startswith('-'):
            # Unknown wrapper option: usage error.
            raise IndexError()
        else:
            pw_filename = arg
            # Only a failure to read the password file may be reported as
            # such; other OS errors are handled where they occur.
            try:
                with open(pw_filename) as pw_filehandle:
                    password = pw_filehandle.readline().rstrip('\n')
            except OSError:
                print('Failed to read password from %s' % (pw_filename,), file=sys.stderr)
                sys.exit(2)

            options = []
            if not ucr.is_true('ssh/StrictHostKeyChecking'):
                options.extend(['-o', 'StrictHostKeyChecking=no'])
            options.extend(['-o', 'ControlPath=none'])
            if not ucr.is_true('ssh/PubkeyAuthentication'):
                options.extend(['-o', 'PubkeyAuthentication=no'])

            # An invocation name without an entry raises KeyError, which is
            # reported as "Unknown command" below.
            command = {
                'univention-ssh': ['ssh', *options],
                'univention-scp': ['scp', *options],
                'univention-ssh-rsync': ['rsync', '-e', 'ssh %s' % ' '.join(shlex.quote(o) for o in options)],
            }[cmd]
            if split:
                # Join the remaining arguments and tokenize them again with
                # shell-like rules.
                sys.argv = shlex.split(' '.join(sys.argv))
            if cmd == 'univention-ssh-rsync':
                # Remove every remote-shell option given by the caller, as the
                # wrapper supplies its own "-e". The program name has already
                # been popped, so the scan has to start at index 0.
                i = 0
                while i < len(sys.argv):
                    arg = sys.argv[i]
                    if arg in ('-e', '--rsh'):
                        print('Overwriting %s option' % (arg,), file=sys.stderr)
                        del sys.argv[i:i + 2]
                    elif arg.startswith('--rsh='):
                        print('Overwriting --rsh option', file=sys.stderr)
                        del sys.argv[i]
                    else:
                        i += 1
            command += sys.argv

            # os.fork() reports failure by raising OSError, never by
            # returning a negative pid.
            try:
                pid = os.fork()
            except OSError:
                print('Failed to fork child process', file=sys.stderr)
                sys.exit(1)

            if pid == 0:  # child
                # the password for the grandchild
                os.environ[UNIVENTION_SSH_ASKPASS_ENV] = password
                # path to self
                os.environ['SSH_ASKPASS'] = UNIVENTION_SSH_ASKPASS_HELPER
                os.environ['SSH_ASKPASS_REQUIRE'] = 'force'

                # os.execvp() only "returns" by raising OSError.
                try:
                    os.execvp(command[0], command)  # noqa: S606
                except OSError:
                    print('Failed to exec %s' % (' '.join(command),), file=sys.stderr)
                    sys.exit(1)
            else:  # parent
                # close all file descriptors
                max_fd = os.sysconf('SC_OPEN_MAX')
                os.closerange(0, max_fd)
                # re-open stdin, stdout, stderr on the null device
                for target_fd, mode in enumerate((os.O_RDONLY, os.O_WRONLY, os.O_WRONLY)):
                    fd = os.open(os.devnull, mode)
                    if fd != target_fd:
                        os.dup2(fd, target_fd)
                        os.close(fd)

                try:
                    # setup timeout: every handled signal aborts the wait
                    def handler(signum, frame):
                        raise TimeoutError()
                    sig_alrm = signal.signal(signal.SIGALRM, handler)  # alarm
                    signal.signal(signal.SIGHUP, handler)  # hang up
                    signal.signal(signal.SIGINT, handler)  # interrupt
                    signal.signal(signal.SIGTERM, handler)  # terminate
                    signal.signal(signal.SIGSEGV, handler)  # memory corrupt

                    signal.alarm(timeout)

                    # wait for child and pass its exit status / signal
                    cpid, status = os.waitpid(pid, 0)

                    # restore SIGALRM
                    signal.signal(signal.SIGALRM, sig_alrm)

                    if os.WIFEXITED(status):
                        rc = os.WEXITSTATUS(status)
                        sys.exit(rc)
                    elif os.WIFSIGNALED(status):
                        # NOTE(review): handler is still installed for SIGHUP,
                        # SIGINT, SIGTERM and SIGSEGV here, so re-raising one
                        # of those ends up in the except clause below instead
                        # of terminating this process - confirm the intent.
                        sig = os.WTERMSIG(status)
                        current_pid = os.getpid()
                        os.kill(current_pid, sig)
                except (OSError, TimeoutError):
                    # Timeout or signal: terminate the child, escalating to
                    # SIGKILL if it has not exited after a grace period.
                    print('Signal.', file=sys.stderr)
                    os.kill(pid, signal.SIGTERM)
                    cpid, status = os.waitpid(pid, os.WNOHANG)
                    if (cpid, status) == (0, 0):
                        time.sleep(1)
                        os.kill(pid, signal.SIGKILL)
                        cpid, status = os.waitpid(pid, 0)
            # the buck stops here: fatal error
            sys.exit(1)
except KeyError:
    print('Unknown command %s' % (cmd,), file=sys.stderr)
except IndexError:
    display_help(sys.stderr)
except OSError as exc:
    # Any remaining operating system error; the password file, fork and exec
    # failures are reported where they occur.
    print('Error: %s' % (exc,), file=sys.stderr)

sys.exit(2)
