#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#  SPDX-License-Identifier: Apache-2.0

import logging
import os
import subprocess
import time
from typing import Any, Dict

from res.utils import sssd_utils


def join_active_directory(auth_entry: Dict[str, Any], logger: logging.Logger) -> None:
    # Rename the existing SSSD config before joining AD.
    # Otherwise, the process could fail if the previous VDI session has already joined AD.
    back_up_sssd(logger)

    otp = auth_entry["otp"]
    domain_controller = auth_entry["domain_controller"]
    hostname = auth_entry["hostname"]

    cmd = [
        "realm",
        "join",
        "--one-time-password",
        otp,
        "--computer-name",
        hostname.upper(),
        "--client-software",
        "sssd",
        "--server-software",
        "active-directory",
        "--membership-software",
        "adcli",
        "--verbose",
        domain_controller,
    ]

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        logger.info(result.stdout)
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to join AD: {e.stderr}")
        return

    configure_sssd(logger)


def is_in_active_directory(_logger: logging.Logger) -> bool:
    return sssd_utils.is_in_active_directory()

def connect_to_active_directory(logger: logging.Logger):
    configure_sssd(logger)

    base_os = os.getenv("RES_BASE_OS")
    if base_os in ["amzn2", "rhel8", "rhel9", "rocky9"]:
        subprocess.check_call(
            [
                "sudo",
                "authconfig",
                "--enablemkhomedir",
                "--enablesssdauth",
                "--enablesssd",
                "--updateall",
            ],
            stdout=subprocess.PIPE,
        )
    elif base_os == "amzn2023":
        subprocess.check_call(
            ["sudo", "authselect", "select", "sssd", "with-mkhomedir", "--force"],
            stdout=subprocess.PIPE,
        )
    elif base_os.startswith("ubuntu"):
        subprocess.check_call(
            ["sudo", "pam-auth-update", "--enable", "sss", "--force"],
            stdout=subprocess.PIPE,
        )

def back_up_sssd(logger: logging.Logger) -> None:
    if os.path.exists(sssd_utils.SSSD_FILE_PATH):
        logger.info(f"Back up existing SSSD config file: {sssd_utils.SSSD_FILE_PATH}")

        os.rename(sssd_utils.SSSD_FILE_PATH, f"{sssd_utils.SSSD_FILE_PATH}.{time.time()}")

def configure_sssd(logger: logging.Logger):
    try:
        sssd_utils.restart_sssd()
    except Exception as e:
        # Avoid throwing exceptions in the long-running application.
        # The application should continue monitoring and trying to restart SSSD upon SSSD config updates.
        logger.error(f"Failed to restart SSSD: {e}")
