#!/usr/bin/python3
# -*- coding: utf-8 -*-
# pylint: disable-msg=C0103
#
# Copyright 2004-2021 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

"""
Check if users are member of their primary group.
"""

from __future__ import print_function

import sys
from argparse import ArgumentParser

import ldap
from ldap.filter import filter_format
from univention.config_registry import ConfigRegistry
import univention.uldap


def info(msg, *args, **kwargs):
	"""Write an informational message to stdout.

	`msg` is %-formatted with the positional arguments if any were given,
	otherwise with the keyword arguments (an empty mapping if there are none).
	"""
	values = args if args else kwargs
	print(msg % values)


def warn(msg, *args, **kwargs):
	"""Write a warning to stderr and count it.

	`msg` is %-formatted with the positional arguments if any were given,
	otherwise with the keyword arguments. The running total is kept in the
	function attribute `warn.warnings`.
	"""
	values = args if args else kwargs
	text = 'Warning: ' + (msg % values)
	print(text, file=sys.stderr)
	warn.warnings = warn.warnings + 1


# Number of warnings issued so far.
warn.warnings = 0


def fatal(msg, *args, **kwargs):
	"""Report an error on stderr and terminate with exit status 1.

	`msg` is %-formatted with the positional arguments if any were given,
	otherwise with the keyword arguments.
	"""
	values = args if args else kwargs
	text = 'Error: ' + (msg % values)
	print(text, file=sys.stderr)
	sys.exit(1)


def check_primary(conn, basedn):
	"""Check if users are member of their primary group.

	Searches all `posixAccount` entries below `basedn`. For every
	`univentionGroup` carrying the same `gidNumber` as the account, the
	account's DN is added to `uniqueMember` and its uid to `memberUid` unless
	already present (compared case-insensitively).

	Accounts lacking a `uid` or a `gidNumber` value are reported with a warning
	and skipped.

	:param conn: LDAP connection object; `search_s()` and `modify_s()` are used.
	:param basedn: base DN of the subtree to check.
	"""
	info("Checking if users are member of their primary group...")
	try:
		# GIDs will only be found in posixAccount
		user_result = conn.search_s(basedn, ldap.SCOPE_SUBTREE, '(objectClass=posixAccount)', ['gidNumber', 'uid'])
	except ldap.NO_SUCH_OBJECT:
		# fatal() exits the script, so user_result is always bound below
		fatal("ldap search in %s failed (no such base dn)", basedn)
	count_changes = 0
	for user_dn, account in user_result:
		uid_values = account.get('uid')
		if not uid_values:
			warn("posixAccount without uid: %s", user_dn)
			continue
		user_uid = uid_values[0].decode('utf-8')

		# A missing attribute must not raise IndexError; without a gidNumber
		# there is no primary group to look up, so skip the account.
		gid_values = account.get('gidNumber')
		user_gid = gid_values[0].decode('utf-8') if gid_values else ''
		if not user_gid:
			warn("posixAccount without gidNumber: %s", user_dn)
			continue

		# search corresponding group
		group_result = conn.search_s(
			basedn, ldap.SCOPE_SUBTREE,
			filter_format('(&(objectClass=univentionGroup)(gidNumber=%s))', (user_gid,)),
			['uniqueMember', 'memberUid']
		)

		# there must be exactly one group with this gid
		# (gidNumber 0 is exempt from the "no group" warning)
		if len(group_result) > 1:
			warn("found more than one univentionGroup for gidNumber=%s!", user_gid)
		elif len(group_result) < 1 and user_gid != "0":
			warn("found no univentionGroup for gidNumber=%s!", user_gid)

		user_dn_lower = user_dn.lower()
		user_uid_lower = user_uid.lower()
		# we change them all -- the user needs to delete all but one of them
		for group_dn, group in group_result:
			# look for the needed entry
			group_member_dns = [dn.decode('utf-8').lower() for dn in group.get('uniqueMember', [])]
			group_member_uids = [uid.decode('utf-8').lower() for uid in group.get('memberUid', [])]
			modlist = []
			if user_dn_lower not in group_member_dns:
				modlist.append((ldap.MOD_ADD, 'uniqueMember', user_dn.encode('utf-8')))
			if user_uid_lower not in group_member_uids:
				modlist.append((ldap.MOD_ADD, 'memberUid', user_uid.encode('utf-8')))
			# no entry found, going to add one
			if modlist:
				info("Adding uniqueMember and memberUid entry for '%s' in '%s'", user_dn, group_dn)
				try:
					conn.modify_s(group_dn, modlist)
					count_changes += 1
				except ldap.LDAPError:
					warn("failed to modify group %s", group_dn)
	info("Checked %d posixAccounts, fixed %d issues.", len(user_result), count_changes)


def check_groups(conn, basedn):
	"""Check if members of group exist.

	Every `posixGroup` below `basedn` is passed to both member checks, first
	the one by `uniqueMember`, then the one by `memberUid`.

	:param conn: LDAP connection object handed on to the per-group checks.
	:param basedn: base DN of the subtree to check.
	"""
	info("Checking if group-members exist...")
	wanted_attrs = ['uniqueMember', 'memberUid']
	try:
		groups = conn.search_s(basedn, ldap.SCOPE_SUBTREE, '(objectClass=posixGroup)', wanted_attrs)
	except ldap.NO_SUCH_OBJECT:
		fatal("ldap search in %s failed (no such base dn)", basedn)

	fixed = 0
	for dn, attrs in groups:
		fixed_by_dn = check_groups_by_dn(conn, dn, attrs)
		fixed_by_uid = check_groups_by_uid(conn, basedn, dn, attrs)
		fixed += fixed_by_dn + fixed_by_uid

	info("Checked %d posixGroups, fixed %d issues.", len(groups), fixed)


def _split_member_dn(member_dn):
	"""Split a member DN into a search filter for its first RDN and its parent DN.

	The DN is parsed instead of being cut at the first comma, so escaped
	commas stay inside the RDN value, filter meta characters in the value are
	escaped, and a multi-valued RDN becomes an AND of all its assertions.

	:param member_dn: string representation of the member's DN.
	:returns: tuple `(filter, parent_dn)`.
	:raises ValueError: if the DN consists of fewer than two RDNs.
	"""
	# function-scope import so that the DN helpers are guaranteed to be loaded
	import ldap.dn

	# NOTE(review): str2dn() is expected to raise an ldap.LDAPError subclass on
	# malformed input -- the caller catches LDAPError and ValueError.
	rdns = ldap.dn.str2dn(member_dn)
	if len(rdns) < 2:
		raise ValueError('DN without parent: %r' % (member_dn,))
	assertions = [filter_format('(%s=%s)', (attr, value)) for attr, value, _flags in rdns[0]]
	if len(assertions) == 1:
		member_filter = assertions[0]
	else:
		member_filter = '(&%s)' % ''.join(assertions)
	return member_filter, ldap.dn.dn2str(rdns[1:])


def check_groups_by_dn(conn, group_dn, group):
	"""Check by 'uniqueMember'.

	Each `uniqueMember` value is looked up with a one-level search below its
	parent DN. Values without any comma and values for which the search finds
	nothing are removed from the group; everything else that is doubtful
	(unparsable DN, failed search, several hits) only produces a warning.

	:param conn: LDAP connection object; `search_s()` and `modify_s()` are used.
	:param group_dn: DN of the group being checked.
	:param group: attribute mapping of the group (values are bytes).
	:returns: number of successfully removed `uniqueMember` values.
	"""
	group_member_dns = [dn.decode('utf-8') for dn in group.get('uniqueMember', [])]
	count_changes = 0
	remmembers = set()
	for member_dn in group_member_dns:
		if ',' not in member_dn:
			# no parent to search below: dropped without a lookup
			remmembers.add(member_dn)
			continue

		# Split uid=USER, cn=user,dc=FQDN
		try:
			member_filter, base = _split_member_dn(member_dn)
		except (ldap.LDAPError, ValueError):
			# never remove a member only because its DN could not be parsed
			warn("Manual: Cannot parse member DN '%s' of group '%s'", member_dn, group_dn)
			continue

		try:
			member_result = conn.search_s(base, ldap.SCOPE_ONELEVEL, member_filter, ['objectClass'])
		except ldap.LDAPError:
			warn("Manual: Search for member DN '%s' of group '%s' failed", member_dn, group_dn)
		else:
			if len(member_result) > 1:
				warn("Manual: Multiple members for DN '%s' of group '%s'", member_dn, group_dn)
			elif len(member_result) < 1:
				warn("No member for DN '%s', will be removed", member_dn)
				remmembers.add(member_dn)
	for member_dn in remmembers:
		info("Removing member DN '%s' from '%s'", member_dn, group_dn)
		modlist = [(ldap.MOD_DELETE, 'uniqueMember', member_dn.encode('utf-8'))]
		try:
			conn.modify_s(group_dn, modlist)
			count_changes += 1
		except ldap.LDAPError:
			warn("failed to remove DN '%s' from group '%s'", member_dn, group_dn)
	return count_changes


def check_groups_by_uid(conn, basedn, group_dn, group):
	"""Check by 'memberUid'.

	Each `memberUid` value is searched as `uid` in the whole subtree below
	`basedn`. Values without any hit are removed from the group; a failed
	search or several hits only produce a warning.

	:param conn: LDAP connection object; `search_s()` and `modify_s()` are used.
	:param basedn: base DN of the subtree searched for the uid.
	:param group_dn: DN of the group being checked.
	:param group: attribute mapping of the group (values are bytes).
	:returns: number of successfully removed `memberUid` values.
	"""
	# decode everything up front, before the first search is issued
	uids = [raw.decode('utf-8') for raw in group.get('memberUid', [])]

	stale = set()
	for uid in uids:
		uid_filter = filter_format('(uid=%s)', (uid,))
		try:
			hits = conn.search_s(basedn, ldap.SCOPE_SUBTREE, uid_filter, ['objectClass'])
		except ldap.LDAPError:
			warn("Manual: Search for member UID '%s' of group '%s' failed", uid, group_dn)
			continue

		hit_count = len(hits)
		if hit_count == 0:
			warn("No member for UID '%s', will be removed", uid)
			stale.add(uid)
		elif hit_count > 1:
			warn("Manual: Multiple members for UID '%s' of group '%s'", uid, group_dn)

	removed = 0
	for uid in stale:
		info("Removing member UID '%s' from '%s'", uid, group_dn)
		deletion = (ldap.MOD_DELETE, 'memberUid', uid.encode('utf-8'))
		try:
			conn.modify_s(group_dn, [deletion])
		except ldap.LDAPError:
			warn("Failed to remove UID '%s' from group '%s'", uid, group_dn)
		else:
			removed += 1
	return removed


def main():
	# type: () -> None
	"""Parse the command line, run both checks and exit.

	The base DN defaults to the UCR variable `ldap/base` and can be overridden
	with `--base-dn`. With `--check` the connection's `modify_s` is replaced by
	a function that does nothing. The exit status is 2 if any warning was
	issued, otherwise 0.
	"""
	parser = ArgumentParser(description=__doc__)
	parser.add_argument(
		"-b", "--base-dn",
		dest="basedn", action="store",
		help="ldap base DN for user search",
	)
	parser.add_argument(
		"-c", "--check",
		dest="check", action="store_true",
		help="Only check, do not modify",
	)
	options = parser.parse_args()

	ucr = ConfigRegistry()
	ucr.load()
	default_base = ucr['ldap/base']

	conn = univention.uldap.getAdminConnection().lo

	basedn = options.basedn if options.basedn else default_base
	if options.check:
		def _discard_modification(dn, modlist):
			return None

		conn.modify_s = _discard_modification

	check_primary(conn, basedn)
	check_groups(conn, basedn)

	if not warn.warnings:
		sys.exit(0)
	info("There were %d warning(s)!", warn.warnings)
	sys.exit(2)


# Run the checks only when this file is executed as a script, not on import.
if __name__ == '__main__':
	main()
