#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Univention Samba
#  helper script: kerberize a user account
#
# Like what you see? Join us!
# https://www.univention.com/about-us/careers/vacancies/
#
# Copyright 2001-2023 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

from __future__ import print_function

import os
import codecs
import sys
import getopt
import hashlib
from datetime import datetime

import ldap
from ldap.filter import filter_format

import univention.config_registry
import univention.debug as ud
import univention.uldap
import univention.lib.policy_result

# Initialise univention.debug with '/dev/null' as its log target.
ud.init('/dev/null', ud.FLUSH, ud.FUNCTION)

# Load the Univention Config Registry; it provides the LDAP base and the
# Kerberos realm used below.
configRegistry = univention.config_registry.ConfigRegistry()
configRegistry.load()

# Module-wide LDAP connection used by main() for searching and modifying.
# NOTE(review): getAdminConnection() presumably needs the local admin
# credentials, i.e. the script has to run with sufficient privileges - confirm.
lo = univention.uldap.getAdminConnection()
# NOTE(review): krbbase is not referenced anywhere else in the visible part of this file.
krbbase = 'ou=krb5,' + configRegistry['ldap/base']
realm = configRegistry['kerberos/realm']

# Number of seconds per day.
UNIXDAY = 3600 * 24


def _get_samba_password_history(newpassword, smbpwhistory, smbpwhlen):
	"""Return a Samba password history extended by the given NT hash.

	Each history entry consists of 64 hex characters::

		[Salt][MD5(Salt+Hash)]
		 ^ first 32 chars     ^ last 32 chars

	where ``Salt`` are 16 random bytes and ``Hash`` is the binary NT hash,
	both halves written as upper-case hex.

	:param newpassword: hex encoded NT password hash (bytes or ASCII str).
	:param smbpwhistory: existing history as text. Entries may be
		concatenated without separator (the form this function returns) or
		separated by whitespace.
	:param smbpwhlen: maximum number of entries to keep. Values below 1
		keep only the newly generated entry.
	:returns: the new history as ``str``, entries concatenated, oldest first.
	"""
	entry_length = 64

	# calculate the password hash & salt
	# in binary for calculating the md5:
	salt = os.urandom(16)
	# we have to have that in hex:
	hexsalt = codecs.encode(salt, 'hex').upper().decode('ASCII')
	# we need the ntpwd binary data too
	pwd = codecs.decode(newpassword, 'hex')
	pwdhash = hashlib.md5(salt + pwd).hexdigest().upper()
	smbpwhash = hexsalt + pwdhash

	# Split the history into its entries. The history is written without
	# separators, so splitting on blanks alone would treat a stored history
	# as one single entry and never drop old hashes.
	pwlist = []
	for token in smbpwhistory.split():
		pwlist.extend(token[pos:pos + entry_length] for pos in range(0, len(token), entry_length))

	# append new hash
	pwlist.append(smbpwhash)
	# strip old hashes; "-0" as slice start would keep the whole list, so
	# always keep at least the entry which was just generated
	pwlist = pwlist[-max(smbpwhlen, 1):]
	# build history
	return ''.join(pwlist)


def unixdayToKrb5Date(unixday):
	"""Convert a UNIX timestamp into a generalized time value.

	Despite its name the argument is a timestamp in seconds since the
	epoch, not a number of days.

	:param unixday: the timestamp as ASCII encoded bytes, e.g. ``b'86400'``.
	:returns: the timestamp formatted as ``YYYYmmddHHMMSSZ`` (ASCII bytes).
	"""
	# local import: the file itself only imports the ``datetime`` class
	from datetime import timezone

	seconds = float(unixday.decode('ASCII'))
	# The trailing "Z" designates UTC, so the conversion must be done in UTC
	# and not in the local time zone of the host.
	date = datetime.fromtimestamp(seconds, tz=timezone.utc).strftime("%Y%m%d%H%M%S") + 'Z'
	return date.encode('ASCII')


def nt_password_to_arcfour_hmac_md5(nt_password):
	"""Build an arcfour-hmac-md5 ``krb5Key`` value from an NT password hash.

	:param nt_password: the NT hash as hex digits; the first 32 digits
		(16 bytes) are used.
	:returns: the constant key header followed by the 16 raw hash bytes.
	"""
	# constant header shared by all arcfour-hmac-md5 keys
	header = b'0\x1d\xa1\x1b0\x19\xa0\x03\x02\x01\x17\xa1\x12\x04\x10'

	# translate the 16 two-digit hex pairs into their byte values
	raw_hash = bytes(int(nt_password[start:start + 2], 16) for start in range(0, 32, 2))
	return header + raw_hash


def main():
	"""Add Kerberos attributes to the Samba account given via ``-u <uid>``.

	The account is looked up in LDAP by its ``uid``. Entries with object
	class ``univentionWindows`` or ``univentionHost``, the user ``root`` and
	accounts without an NT password hash are skipped. For the remaining
	entries the following is written in one LDAP modification:

	* object classes ``krb5Principal`` / ``krb5KDCEntry`` and their
	  attributes, if missing,
	* ``krb5Key`` (derived from ``sambaNTPassword``) and ``krb5KDCFlags``,
	* ``sambaAcctFlags`` and ``sambaPasswordHistory``, if unset,
	* password expiry attributes derived from ``sambaPwdLastSet`` and the
	  ``univentionPWExpiryInterval`` policy,
	* ``userPassword`` as ``{LANMAN}`` value if an LM hash is set.

	The script exits silently with status 0 if no user name was given.
	"""
	username = None
	optlist, _args = getopt.getopt(sys.argv[1:], 'u:')
	for option, value in optlist:
		if option == '-u':
			username = value

	if not username:
		sys.exit(0)

	ldap_attr = ['uid', 'krb5Key', 'sambaNTPassword', 'sambaLMPassword', 'sambaAcctFlags', 'objectClass', 'userPassword', 'krb5PasswordEnd', 'sambaPwdCanChange', 'sambaPwdMustChange', 'sambaPwdLastSet', 'krb5KDCFlags', 'shadowLastChange', 'sambaPasswordHistory', 'shadowMax']
	ldap_filter = filter_format('(&(objectClass=sambaSamAccount)(sambaNTPassword=*)(uid=%s)(!(objectClass=univentionWindows)))', [username])
	for dn, attrs in lo.search(filter=ldap_filter, attr=ldap_attr):
		if b'univentionHost' in attrs['objectClass']:
			continue

		if attrs['sambaNTPassword'][0] == b"NO PASSWORDXXXXXXXXXXXXXXXXXXXXX":
			print('Can not add Kerberos key for %s...' % repr(dn), end=' ')
			print('no password set')
			continue

		if attrs['uid'][0] == b'root':
			print('Skipping user root ')
			continue

		# check if the user was disabled
		sambaAcctFlags = attrs.get('sambaAcctFlags', [b''])[0]
		disabled = b'D' in sambaAcctFlags

		ocs = []  # object classes which have to be added
		ml = []  # modlist of (attribute, old value, new value) tuples
		if b'krb5Principal' not in attrs['objectClass']:
			ocs.append('krb5Principal')
			principal = b'%s@%s' % (attrs['uid'][0], realm.encode('UTF-8'))
			ml.append(('krb5PrincipalName', None, principal))

		flag = b'256' if disabled else b'126'

		if not sambaAcctFlags:
			ml.append(('sambaAcctFlags', None, b'[U          ]'))

		# policy_result() returns a tuple; its first element maps policy
		# attribute names to lists of values
		policies = univention.lib.policy_result.policy_result(dn)[0]
		if not attrs.get('sambaPasswordHistory', False):
			smbpwhlen = int(policies.get('univentionPWHistoryLen', [0])[0])
			history = _get_samba_password_history(attrs.get('sambaNTPassword', [b''])[0], '', smbpwhlen)
			# LDAP attribute values have to be bytes
			ml.append(('sambaPasswordHistory', None, history.encode('ASCII')))

		if attrs.get('sambaPwdLastSet', False):
			usersPWExpireInterval = int(policies.get('univentionPWExpiryInterval', [0])[0])

			oldkrb5PasswordEndValue = attrs.get('krb5PasswordEnd', [None])[0]
			oldsambaPwdCanChangeValue = attrs.get('sambaPwdCanChange', [None])[0]
			oldsambaPwdMustChangeValue = attrs.get('sambaPwdMustChange', [None])[0]
			oldshadowMaxValue = attrs.get('shadowMax', [None])[0]
			oldshadowLastChangeValue = attrs.get('shadowLastChange', [None])[0]
			sambaPwdLastSetValue = int(attrs['sambaPwdLastSet'][0].decode('ASCII'))

			# shadowLastChange is counted in days, the samba attributes in seconds
			shadowLastChangeValue = str(sambaPwdLastSetValue // UNIXDAY).encode('ASCII')
			sambaPwdCanChangeValue = str(sambaPwdLastSetValue + UNIXDAY).encode('ASCII')
			if usersPWExpireInterval:
				# expiry interval policy is set: calculate the expiry dates
				sambaPwdMustChangeValue = str(sambaPwdLastSetValue + int(usersPWExpireInterval * UNIXDAY)).encode('ASCII')
				# unixdayToKrb5Date() already returns bytes; wrapping it in
				# str() would store the literal "b'...'" representation
				krb5PasswordEndValue = unixdayToKrb5Date(sambaPwdMustChangeValue)
				shadowMaxValue = str(usersPWExpireInterval).encode('ASCII')
			else:
				# no expiry interval policy: remove expiry interval and dates
				sambaPwdMustChangeValue = None
				krb5PasswordEndValue = None
				shadowMaxValue = None
			ml.extend([
				('krb5PasswordEnd', oldkrb5PasswordEndValue, krb5PasswordEndValue),
				('sambaPwdCanChange', oldsambaPwdCanChangeValue, sambaPwdCanChangeValue),
				('sambaPwdMustChange', oldsambaPwdMustChangeValue, sambaPwdMustChangeValue),
				('shadowMax', oldshadowMaxValue, shadowMaxValue),
				('shadowLastChange', oldshadowLastChangeValue, shadowLastChangeValue),
			])
		else:
			print('Could not find attribute "sambaPwdLastSet". Skipping generating of "krb5PasswordEnd".')

		if b'krb5KDCEntry' not in attrs['objectClass']:
			ocs.append('krb5KDCEntry')
			ml.extend([
				('krb5MaxLife', None, b'86400'),
				('krb5MaxRenew', None, b'604800'),
				('krb5KeyVersionNumber', None, b'1'),
			])

		old_flag = attrs.get('krb5KDCFlags', [])
		old_keys = attrs.get('krb5Key', [])

		ml.extend([
			('krb5Key', old_keys, nt_password_to_arcfour_hmac_md5(attrs['sambaNTPassword'][0])),
			('krb5KDCFlags', old_flag, flag),
		])

		if attrs.get('sambaLMPassword', [None])[0] not in [b"NO PASSWORDXXXXXXXXXXXXXXXXXXXXX", None]:
			old_password = attrs.get('userPassword', [])
			# a disabled account gets a "!" in front of the LM hash
			if disabled:
				new_password = [b'{LANMAN}!%s' % attrs['sambaLMPassword'][0]]
			else:
				new_password = [b'{LANMAN}%s' % attrs['sambaLMPassword'][0]]
			ml.append(('userPassword', old_password, new_password))

		if ocs:
			print('Adding Kerberos key for %r...' % (dn,), end=' ')
			# the object classes must come first in the modlist
			ml.insert(0, ('objectClass', None, [x.encode('UTF-8') for x in ocs]))

		try:
			lo.modify(dn, ml)
		except ldap.ALREADY_EXISTS:
			print('already exists')
		else:
			print('done')


if __name__ == '__main__':
	main()
