#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Like what you see? Join us!
# https://www.univention.com/about-us/careers/vacancies/
#
# Copyright 2021-2022 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.
#

from __future__ import print_function

import os
import re
import subprocess
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace  # noqa: F401
from glob import glob
from hashlib import md5
from typing import Container, Dict, Iterable, List, Optional, Set, Tuple, Union  # noqa: F401

from univention.ldap_cache.cache import Caches, get_cache  # noqa: F401
from univention.ldap_cache.cache.shard_config import add_shard_to_config, rm_shard_from_config
from univention.uldap import getMachineConnection

# Source code template for the listener modules written by
# create_listener_modules(). It is rendered with str.format(); the
# placeholders are:
#   {name}   - name of the listener module (module level and Configuration)
#   {desc}   - text inserted into the description of the module
#   {filter} - value returned by get_ldap_filter()
#   {attrs}  - value returned by get_attributes()
# Every placeholder uses the "!r" conversion, so the values end up in the
# generated file as Python literals (their repr()).
listener_template = '''#!/usr/bin/python3
from __future__ import absolute_import

from univention.ldap_cache.listener_module import LdapCacheHandler

name = {name!r}

class QueryHandler(LdapCacheHandler):
    class Configuration(LdapCacheHandler.Configuration):
        name = {name!r}
        def get_description(self):
            return 'Automatically created listener module for the univention-group-membership-member cache (%s)' % ({desc!r},)
        def get_ldap_filter(self):
            return {filter!r}
        def get_attributes(self):
            return {attrs!r}
'''


def _md5sum(value):
    # type: (str) -> str
    """Return the hexadecimal MD5 digest of *value*.

    The text is encoded as UTF-8 before hashing.
    """
    digest = md5(value.encode('utf-8'))
    return digest.hexdigest()


def _query_objects(query, attrs):
    # type: (str, Iterable[Union[bytes, str]]) -> List[Tuple[str, Dict[str, List[bytes]]]]
    """Search the LDAP directory for *query* using the machine connection.

    :param query: the LDAP filter to search for.
    :param attrs: names of the attributes to request; ``bytes`` names are
        decoded as UTF-8, all other values are passed on unchanged.
    :returns: whatever ``search()`` of the machine connection returns
        (according to the type comment a list of ``(dn, attributes)`` tuples).
    """
    connection = getMachineConnection()
    attr_names = []
    for attr in attrs:
        if isinstance(attr, bytes):
            attr = attr.decode('utf-8')
        attr_names.append(attr)
    print('Querying', query, 'with', attr_names)
    return connection.search(query, attr=attr_names)


def _get_queries(caches, cache_names=None):
    # type: (Caches, Optional[Container[str]]) -> Dict[str, Tuple[List, Set]]
    """Group the shards of the selected caches by their LDAP filter.

    :param caches: iterable yielding ``(name, cache)`` pairs; every cache
        provides its shards in ``cache.shards``.
    :param cache_names: if not ``None``, only caches whose name is contained
        in it are considered.
    :returns: a dict mapping each LDAP filter to a tuple of
        ``(shards using that filter, attribute names needed by those shards)``.
        The attribute names are the key, the value and the additional
        attributes of every shard.
    """
    by_filter = {}  # type: Dict[str, Tuple[List, Set]]
    for cache_name, cache in caches:
        wanted = cache_names is None or cache_name in cache_names
        if not wanted:
            continue
        for shard in cache.shards:
            ldap_filter = shard.ldap_filter
            if ldap_filter not in by_filter:
                by_filter[ldap_filter] = ([], set())
            filter_shards, filter_attrs = by_filter[ldap_filter]
            filter_shards.append(shard)
            filter_attrs.update((shard.key, shard.value))
            filter_attrs.update(shard.attributes)
    return by_filter


def add_cache(args):
    # type: (Namespace) -> None
    """Handler of the ``add-cache`` action.

    Passes the shard definition given on the command line (database name,
    single-value flag, reverse flag, key, value and LDAP filter) to
    ``add_shard_to_config()``.
    """
    shard_definition = (
        args.db_name,
        args.single_value,
        args.reverse,
        args.key,
        args.value,
        args.ldap_filter,
    )
    add_shard_to_config(*shard_definition)


def rm_cache(args):
    # type: (Namespace) -> None
    """Handler of the ``rm-cache`` action.

    Passes the shard definition given on the command line (database name,
    single-value flag, reverse flag, key, value and LDAP filter) to
    ``rm_shard_from_config()``.
    """
    shard_definition = (
        args.db_name,
        args.single_value,
        args.reverse,
        args.key,
        args.value,
        args.ldap_filter,
    )
    rm_shard_from_config(*shard_definition)


def cleanup(args):
    # type: (Namespace) -> None
    """Handler of the ``cleanup`` action.

    Calls ``cleanup()`` on every sub cache returned by ``get_cache()``.
    *args* is accepted to match the handler signature but is not used.
    """
    for _name, sub_cache in get_cache():
        sub_cache.cleanup()


def rebuild(args):
    # type: (Namespace) -> None
    """Handler of the ``rebuild`` action.

    Clears the selected sub caches and refills them from LDAP. When
    ``args.cache_name`` is empty, all sub caches are rebuilt.

    For every distinct LDAP filter of the selected caches one search is
    done and each result is handed to all shards that use this filter.
    """
    caches = get_cache()
    selected = args.cache_name or sorted([entry[0] for entry in caches])
    print('Rebuilding', selected)

    # Empty all selected caches first; they are refilled below.
    for name, cache in caches:
        if name not in selected:
            continue
        cache.clear()

    for ldap_filter, (shards, attrs) in _get_queries(caches, selected).items():
        print('Searching for', ldap_filter)
        # 'dn' is not requested as an attribute from the search
        attrs.discard('dn')
        count = 0
        for count, obj in enumerate(_query_objects(ldap_filter, attrs), 1):
            if count % 1000 == 0:
                # progress output on a single line, plus a periodic cleanup of all caches
                print('\rProcessing object #', count, end='')
                sys.stdout.flush()
                cleanup(args)
            for shard in shards:
                shard.add_object(obj)
        if count >= 1000:
            # terminate the progress line written above
            print()
        print('Added', count, 'objects')


def create_listener_modules(args):
    # type: (Namespace) -> None
    """Handler of the ``create-listener-modules`` action.

    Writes one listener module per distinct LDAP filter used by the cache
    shards, removes ``ldap-cache-*`` files that are no longer needed and
    restarts the univention-directory-listener service. Afterwards the
    handler status files of the removed modules are deleted as well.

    *args* is accepted to match the handler signature but is not used.
    """
    listener_dir = '/usr/lib/univention-directory-listener/system/'
    # Start with everything that currently exists; whatever has not been
    # (re)written below is considered stale and gets removed.
    existing_listeners = set(glob(os.path.join(listener_dir, 'ldap-cache-*')))
    for query, (_caches, attrs) in _get_queries(get_cache()).items():
        attrs.discard('dn')
        attrs.discard('entryUUID')
        # The file name is derived from the filter, so the same filter always maps to the same module.
        listener_name = 'ldap-cache-%s' % _md5sum(query)
        fname = os.path.join(listener_dir, '%s.py' % listener_name)
        existing_listeners.discard(fname)
        print('Writing', fname, 'for', query)
        # Python 3 reads source files as UTF-8 by default and repr() keeps
        # printable non-ASCII characters, so the module must be written as
        # UTF-8 regardless of the locale this tool runs in.
        with open(fname, 'w', encoding='utf-8') as fd:
            fd.write(listener_template.format(
                name=listener_name,
                desc=', '.join(sorted({cache.db_name for cache in _caches})),
                filter=query,
                attrs=sorted(attrs),
            ))
    for fname in existing_listeners:
        print('Removing', fname)
        os.unlink(fname)
    subprocess.call(['service', 'univention-directory-listener', 'restart'])
    for fname in existing_listeners:
        # also remove status "initialized" so that the listener may be added again in the future
        # needs to be done after the listener is running again
        listener_name = os.path.basename(fname)[:-3]  # strip the ".py" suffix
        try:
            os.unlink(os.path.join('/var/lib/univention-directory-listener/handlers/', listener_name))
        except EnvironmentError:
            # best effort: the status file may not exist
            pass


def list_caches(args):
    # type: (Namespace) -> None
    """Handler of the ``list`` action.

    Prints the name of every sub cache and, for each of its shards, the LDAP
    filter and the mapping it stores. A value in square brackets is used for
    reverse shards and for shards that are not single-valued.
    *args* is accepted to match the handler signature but is not used.
    """
    for cache_name, cache in get_cache():
        print(cache_name)
        print(' The following objects store data:')
        for shard in cache.shards:
            print('  ', shard.ldap_filter)
            if shard.reverse:
                # reverse shards map from the value to the keys
                source, target = shard.value, '[{}]'.format(shard.key)
            elif shard.single_value:
                source, target = shard.key, shard.value
            else:
                source, target = shard.key, '[{}]'.format(shard.value)
            print('    ', source, '=>', target)


def query(args):
    # type: (Namespace) -> None
    """Handler of the ``query`` action.

    Loads the sub cache named ``args.cache_name`` and prints its entries as
    ``key => value`` lines, sorted by key. If ``args.pattern`` is given, only
    keys in which the regular expression is found are printed.

    Prints a message and returns if the sub cache does not exist or the
    pattern is not a valid regular expression.
    """
    cache = get_cache().get_sub_cache(args.cache_name)
    if not cache:
        print('No cache named', args.cache_name)
        return
    entries = cache.load()

    if args.pattern:
        try:
            matches = re.compile(args.pattern).search
        except re.error:
            print('Broken pattern')
            return
    else:
        # no pattern given: every key is shown
        def matches(_key):
            return True

    for key in sorted(entries):
        if matches(key):
            print(key, '=>', entries[key])


def main():
    # type: () -> None
    """Parse the command line and run the handler of the chosen action.

    Every action is a sub parser that stores its handler function as
    ``func``; after parsing, that function is called with the parsed
    arguments. An action is required.
    """
    parser = ArgumentParser(
        description='The LDAP cache stores some portions of the LDAP database so that it can be accessed fast and reliably.',
    )
    actions = parser.add_subparsers(
        description='type %(prog)s <action> --help for further help and possible arguments',
        metavar='action',
        required=True,
    )

    query_parser = actions.add_parser(
        'query',
        description='Queries a sub cache and shows the value(s). Mainly for test purposes, this tool is not meant for real applications',
        help='Query the cache',
    )
    query_parser.add_argument('cache_name', help='The name of the sub cache. See "list"')
    query_parser.add_argument('pattern', nargs='?', help='Queries the key with this regexp')
    query_parser.set_defaults(func=query)

    list_parser = actions.add_parser(
        'list',
        description='Lists all sub caches of the cache. Each sub cache may be fed from multiple sources',
        help='List all parts of the cache',
    )
    list_parser.set_defaults(func=list_caches)

    rebuild_parser = actions.add_parser(
        'rebuild',
        description='Rebuild the cache completely, retrieve the objects and overwrite all previous data',
        help='Rebuild the cache',
    )
    rebuild_parser.add_argument('cache_name', nargs='*', help='The cache consists of different parts. You can only rebuild certain parts of the cache. See "list"')
    rebuild_parser.set_defaults(func=rebuild)

    listener_parser = actions.add_parser(
        'create-listener-modules',
        description='Automatically creates listener modules that will eventually fill the cache (and removes unnecessary); restarts the univention-directory-listener. May be needed after shards are added to /removed from the cache',
        help='Create listener modules',
    )
    listener_parser.set_defaults(func=create_listener_modules)

    # add-cache and rm-cache take the same arguments and differ only in their handler.
    # Both are registered without a help text and with a suppressed description.
    for action_name, handler in (('add-cache', add_cache), ('rm-cache', rm_cache)):
        shard_parser = actions.add_parser(action_name, description=SUPPRESS)
        shard_parser.add_argument('db_name')
        shard_parser.add_argument('--single-value', action='store_true')
        shard_parser.add_argument('--reverse', action='store_true')
        shard_parser.add_argument('key')
        shard_parser.add_argument('value')
        shard_parser.add_argument('ldap_filter')
        shard_parser.set_defaults(func=handler)

    cleanup_parser = actions.add_parser(
        'cleanup',
        description='Over the time, the database may become polluted. Shrink all caches to their optimal size if possible.',
        help='Clean up cache',
    )
    cleanup_parser.set_defaults(func=cleanup)

    args = parser.parse_args()
    args.func(args)


# Run the command line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
