/*
 * Decompiled with CFR 0.152.
 */
package com.teamscale.core.authenticate.saml;

import com.teamscale.core.authenticate.saml.SamlAuthenticationOption;
import com.teamscale.core.authenticate.saml.SamlUser;
import com.teamscale.core.authenticate.saml.SamlUtils;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.core.UriBuilder;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.shibboleth.shared.codec.Base64Support;
import net.shibboleth.shared.codec.DecodingException;
import net.shibboleth.shared.resolver.ResolverException;
import net.shibboleth.shared.security.impl.RandomIdentifierGenerationStrategy;
import net.shibboleth.shared.xml.ParserPool;
import net.shibboleth.shared.xml.XMLParserException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.conqat.lib.commons.assertion.CCSMAssert;
import org.conqat.lib.commons.date.DateTimeUtils;
import org.conqat.lib.commons.net.UrlUtils;
import org.conqat.lib.commons.string.StringUtils;
import org.jetbrains.annotations.VisibleForTesting;
import org.opensaml.core.config.InitializationException;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.XMLObjectBuilderFactory;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.messaging.encoder.MessageEncodingException;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.saml2.binding.encoding.impl.HTTPRedirectDeflateEncoder;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AudienceRestriction;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Condition;
import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedAttribute;
import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.Status;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.SingleSignOnService;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.credential.Credential;
import org.opensaml.xmlsec.encryption.support.DecryptionException;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureValidator;
import org.w3c.dom.Node;

/**
 * SAML 2.0 service-provider helper that builds HTTP-Redirect authentication requests for configured identity
 * providers (IdPs) and parses/validates the responses they post back.
 *
 * <p>Values derived from an IdP's metadata (SSO location, entity ID, credential used to verify signatures) are
 * cached. Every cache is keyed by the raw metadata XML of a {@link SamlAuthenticationOption}, so a changed
 * configuration produces new keys rather than stale hits; {@link #clearCaches()} frees the memory.
 *
 * <p>Locking: {@link #ssoLocationCache} is guarded by its own monitor, so that building request URLs does not
 * contend with response processing. {@link #credentialCache} and {@link #entityIdCache} are guarded by
 * {@code this}.
 */
public class SamlConfigurationCache {

    /** Prefix of the standard SAML 2.0 status code URIs; stripped to keep error messages readable. */
    private static final String STATUS_CODE_PREFIX = "urn:oasis:names:tc:SAML:2.0:status:";

    /** Protocol support enumeration used to look up the IdP SSO descriptor in the metadata. */
    private static final String SAML2_PROTOCOL = "urn:oasis:names:tc:SAML:2.0:protocol";

    /** Binding with which the authentication request is sent to the IdP. */
    private static final String HTTP_REDIRECT_BINDING = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect";

    /** Binding the IdP is asked to use when sending the response back to us. */
    private static final String HTTP_POST_BINDING = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST";

    /** Path of the assertion consumer service endpoint, relative to the server's base URI. */
    private static final String ASSERTION_CONSUMER_SERVICE_PATH = "api/auth/saml/authenticate";

    private static final Logger LOGGER = LogManager.getLogger();

    /** Lazily created singleton; only accessed through the synchronized {@link #getInstance()}. */
    private static SamlConfigurationCache instance;

    /** Generator for the IDs of outgoing authentication requests. */
    private final RandomIdentifierGenerationStrategy secureRandomIdGenerator = new RandomIdentifierGenerationStrategy();

    /** Metadata XML -> location of the IdP's HTTP-Redirect SSO endpoint. Guarded by its own monitor. */
    private final Map<String, String> ssoLocationCache = new HashMap<>();

    /** Metadata XML -> credential used to validate the IdP's signatures. Guarded by {@code this}. */
    private final Map<String, Credential> credentialCache = new HashMap<>();

    /** Metadata XML -> entity ID of the IdP. Guarded by {@code this}. */
    private final Map<String, String> entityIdCache = new HashMap<>();

    private SamlConfigurationCache() throws InitializationException {
        SamlUtils.ensureInitialized();
    }

    /**
     * Returns the singleton, creating it (and initializing the SAML library) on first use.
     *
     * <p>If initialization fails, the error is logged and {@link CCSMAssert#fail} is invoked.
     */
    public static synchronized SamlConfigurationCache getInstance() {
        if (instance == null) {
            try {
                instance = new SamlConfigurationCache();
            } catch (InitializationException e) {
                LOGGER.fatal("Could not initialize SAML library: " + e.getMessage(), e);
                CCSMAssert.fail("Failed to initialize SAML library! See log for details!");
            }
        }
        return instance;
    }

    /**
     * Builds the URL to which the browser is redirected to start a SAML login at the given IdP.
     *
     * @param samlOption        the configured IdP to authenticate against
     * @param baseUri           base URI of this server; the assertion consumer service URL is derived from it
     * @param redirectionTarget optional value appended as the {@code RelayState} query parameter
     *                          ({@code null} is sent as an empty string)
     * @return the request URL, or empty if it could not be built (the cause is logged)
     */
    public Optional<String> getRequestUrl(SamlAuthenticationOption samlOption, URI baseUri,
            @Nullable String redirectionTarget) {
        try {
            String requestUrl = buildRequestUrl(baseUri, samlOption);
            String relayState = UrlUtils.encodeQueryParameter(StringUtils.emptyIfNull(redirectionTarget));
            return Optional.of(requestUrl + "&RelayState=" + relayState);
        } catch (BadRequestException | IOException | ResolverException | XMLParserException | UnmarshallingException
                | MessageEncodingException e) {
            LOGGER.error("Building SAML request URL failed: " + e.getMessage(), e);
            return Optional.empty();
        }
    }

    /**
     * Creates a deflated, Base64- and URL-encoded {@code AuthnRequest} and appends it as the {@code SAMLRequest}
     * query parameter to the IdP's HTTP-Redirect SSO location.
     */
    private String buildRequestUrl(URI baseUri, SamlAuthenticationOption samlOption) throws IOException,
            ResolverException, MessageEncodingException, BadRequestException, XMLParserException,
            UnmarshallingException {
        AuthnRequest authnRequest = createAuthenticationRequest(baseUri, samlOption.serviceProviderId);
        String encodedRequest = UrlUtils
                .encodeQueryParameter(new RequestEncodingHelper().deflateAndBase64Encode(authnRequest));
        String ssoLocation = getSsoLocation(samlOption.metadataXml);
        return ssoLocation + determineUrlSeparator(ssoLocation) + "SAMLRequest=" + encodedRequest;
    }

    /**
     * Returns the (cached) location of the IdP's HTTP-Redirect SSO endpoint described by the given metadata.
     *
     * @throws ResolverException if the metadata has no usable HTTP-Redirect SSO endpoint
     */
    private String getSsoLocation(String metadataXml)
            throws ResolverException, XMLParserException, UnmarshallingException {
        synchronized (ssoLocationCache) {
            String location = ssoLocationCache.get(metadataXml);
            if (location == null) {
                location = getSsoService(SamlUtils.parseEntityDescriptor(metadataXml)).getLocation();
                // Without a location no request URL can be built; fail here instead of producing "null?SAMLRequest=".
                if (StringUtils.isEmpty(location)) {
                    throw new ResolverException("The HTTP-Redirect single sign-on service of the IdP has no location!");
                }
                ssoLocationCache.put(metadataXml, location);
            }
            return location;
        }
    }

    /** Returns the separator for appending a query parameter, depending on whether a query already exists. */
    private static String determineUrlSeparator(String location) {
        return location.contains("?") ? "&" : "?";
    }

    /**
     * Creates a new {@code AuthnRequest} with a fresh random ID, the current time as issue instant, the given
     * service provider ID as issuer, and the HTTP-POST binding requested for the response.
     */
    private AuthnRequest createAuthenticationRequest(URI baseUri, String serviceProviderId) {
        XMLObjectBuilderFactory builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
        AuthnRequest authnRequest = (AuthnRequest) Objects
                .requireNonNull(builderFactory.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME))
                .buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
        authnRequest.setIssueInstant(getNow());
        authnRequest.setProtocolBinding(HTTP_POST_BINDING);
        authnRequest.setID(getIdentifier());
        authnRequest.setAssertionConsumerServiceURL(
                UriBuilder.fromUri(baseUri).path(ASSERTION_CONSUMER_SERVICE_PATH).build().toString());
        authnRequest.setIssuer(createIssuer(builderFactory, serviceProviderId));
        return authnRequest;
    }

    /** Creates an {@code Issuer} element carrying the given service provider ID. */
    private static Issuer createIssuer(XMLObjectBuilderFactory builderFactory, String serviceProviderId) {
        Issuer issuer = (Issuer) Objects.requireNonNull(builderFactory.getBuilder(Issuer.DEFAULT_ELEMENT_NAME))
                .buildObject(Issuer.DEFAULT_ELEMENT_NAME);
        issuer.setValue(serviceProviderId);
        return issuer;
    }

    /** Returns the current time; overridable in tests. */
    @VisibleForTesting
    Instant getNow() {
        return DateTimeUtils.now();
    }

    /** Returns a new random identifier for an outgoing request. */
    String getIdentifier() {
        return secureRandomIdGenerator.generateIdentifier();
    }

    /**
     * Returns the first SSO service of the IdP descriptor that uses the HTTP-Redirect binding.
     *
     * @throws ResolverException if the metadata has no SAML 2.0 IdP SSO descriptor or no redirect binding
     */
    private static SingleSignOnService getSsoService(EntityDescriptor entityDescriptor) throws ResolverException {
        var idpDescriptor = entityDescriptor.getIDPSSODescriptor(SAML2_PROTOCOL);
        if (idpDescriptor == null) {
            throw new ResolverException("Metadata does not contain a SAML 2.0 IdP SSO descriptor!");
        }
        for (SingleSignOnService ssoService : idpDescriptor.getSingleSignOnServices()) {
            // Constant-first comparison: a service without a Binding attribute is simply skipped.
            if (HTTP_REDIRECT_BINDING.equals(ssoService.getBinding())) {
                return ssoService;
            }
        }
        throw new ResolverException("Did not find any redirect binding!");
    }

    /** Returns the (cached) credential extracted from the given metadata for signature validation. */
    private synchronized Credential getCredential(String metadataXml) throws BadRequestException {
        Credential credential = credentialCache.get(metadataXml);
        if (credential == null) {
            credential = SamlUtils.extractCredential(metadataXml);
            credentialCache.put(metadataXml, credential);
        }
        return credential;
    }

    /** Returns the (cached) entity ID extracted from the given metadata. */
    private synchronized String getEntityId(String metadataXml) throws BadRequestException {
        String entityId = entityIdCache.get(metadataXml);
        if (entityId == null) {
            entityId = SamlUtils.extractEntityId(metadataXml);
            entityIdCache.put(metadataXml, entityId);
        }
        return entityId;
    }

    /**
     * Parses and validates a Base64-encoded SAML response and extracts the authenticated user.
     *
     * <p>The response issuer selects the matching option, and the first (decrypted) assertion is used. That
     * assertion's conditions are checked. Its signature is validated, or the response signature if the
     * assertion is unsigned.
     *
     * @param samlResponse Base64-encoded SAML response XML
     * @param samlOptions  all configured IdPs
     * @return the user described by the assertion
     * @throws BadRequestException if the response is malformed, from an unknown issuer, invalid or unsigned
     */
    public synchronized SamlUser getUserFromResponse(String samlResponse,
            Collection<SamlAuthenticationOption> samlOptions) {
        Response response = parseResponse(samlResponse);
        SamlAuthenticationOption samlServer = findSamlServerForResponse(response, samlOptions);
        Assertion assertion = getFirstAssertion(response, samlServer);
        checkConditions(assertion, samlServer);
        // A signature on the assertion itself is preferred; otherwise the enclosing response must be signed.
        Signature signature = Optional.ofNullable(assertion.getSignature()).orElseGet(response::getSignature);
        checkSignature(signature, samlServer);
        return new SamlUser(extractSubjectNameId(assertion, samlServer), extractAttributes(assertion, samlServer),
                samlServer);
    }

    /**
     * Decodes and unmarshalls the Base64-encoded response.
     *
     * @throws BadRequestException if decoding/parsing fails or the root element is not a SAML {@code Response}
     */
    private Response parseResponse(String samlResponse) throws BadRequestException {
        try {
            String xml = convertResponseToXmlString(samlResponse);
            XMLObject xmlObject = XMLObjectSupport.unmarshallFromReader(
                    Objects.requireNonNull(XMLObjectProviderRegistrySupport.getParserPool()), new StringReader(xml));
            if (!(xmlObject instanceof Response response)) {
                throw new BadRequestException(
                        "Expected SAML Response, but received " + String.valueOf(xmlObject.getClass()));
            }
            return response;
        } catch (DecodingException | XMLParserException | UnmarshallingException e) {
            throw new BadRequestException(appendCauseMessages("Invalid payload: no valid SAML 2.0 XML", e), e);
        }
    }

    /**
     * Base64-decodes the response into a UTF-8 XML string.
     *
     * <p>NOTE: at debug level both the raw and the decoded response are logged; they may contain personal data.
     */
    @VisibleForTesting
    String convertResponseToXmlString(String samlResponse) throws DecodingException {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("SAML 2.0 response was " + samlResponse);
        }
        String xml = new String(Base64Support.decode(samlResponse), StandardCharsets.UTF_8);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("SAML 2.0 response as XML was " + xml);
        }
        return xml;
    }

    /**
     * Finds the configured option whose metadata entity ID equals the response's issuer.
     *
     * @throws BadRequestException if the response has no issuer or no option matches it
     */
    private SamlAuthenticationOption findSamlServerForResponse(Response response,
            Collection<SamlAuthenticationOption> samlOptions) throws BadRequestException {
        Issuer responseIssuer = response.getIssuer();
        if (responseIssuer == null) {
            throw new BadRequestException("Missing issuer in SAML response!");
        }
        String issuer = responseIssuer.getValue();
        if (issuer != null) {
            for (SamlAuthenticationOption samlOption : samlOptions) {
                if (issuer.equals(getEntityId(samlOption.metadataXml))) {
                    return samlOption;
                }
            }
        }
        throw new BadRequestException("No matching SAML server option found for issuer " + issuer);
    }

    /**
     * Returns the first assertion of the response. If the response contains encrypted assertions, the first of
     * those is decrypted and returned instead of any plain assertion.
     *
     * @throws BadRequestException if decryption fails or the response contains no assertion at all
     */
    public static Assertion getFirstAssertion(Response response, SamlAuthenticationOption option)
            throws BadRequestException {
        List<EncryptedAssertion> encryptedAssertions = response.getEncryptedAssertions();
        if (!encryptedAssertions.isEmpty()) {
            try {
                return SamlUtils.getDecrypter(option).decrypt(encryptedAssertions.get(0));
            } catch (DecryptionException e) {
                throw new BadRequestException(appendCauseMessages("Could not decrypt encrypted assertion", e), e);
            }
        }
        List<Assertion> assertions = response.getAssertions();
        if (assertions.isEmpty()) {
            throw new BadRequestException(extractNoAssertionMessage(response));
        }
        return assertions.get(0);
    }

    /**
     * Builds an error message for a response without assertions, including the status message and the
     * (possibly nested) status code if present, one item per line.
     */
    public static String extractNoAssertionMessage(Response response) {
        List<String> messages = new ArrayList<>();
        messages.add("No assertions returned by the IdP.");
        Optional<Status> status = Optional.ofNullable(response.getStatus());
        status.map(Status::getStatusMessage).map(XSString::getValue)
                .ifPresent(message -> messages.add("Status Message: " + message));
        status.map(Status::getStatusCode).flatMap(SamlConfigurationCache::extractHierarchicalStatusCode)
                .ifPresent(code -> messages.add("Status Code: " + code));
        return String.join("\n", messages);
    }

    /**
     * Renders a status code and its nested sub-codes as {@code "Outer > Inner > ..."}, with the standard status
     * URI prefix stripped. A missing code value is rendered as an empty string rather than failing.
     */
    private static Optional<String> extractHierarchicalStatusCode(StatusCode statusCode) {
        String value = StringUtils.stripPrefix(StringUtils.emptyIfNull(statusCode.getValue()), STATUS_CODE_PREFIX);
        return Optional.ofNullable(statusCode.getStatusCode())
                .flatMap(SamlConfigurationCache::extractHierarchicalStatusCode)
                .map(child -> value + " > " + child)
                .or(() -> Optional.of(value));
    }

    /**
     * Checks the assertion's validity window and audience restrictions, if it carries conditions.
     *
     * @throws BadRequestException if the assertion is not yet or no longer valid, or any audience restriction
     *                             does not list the configured service provider ID
     */
    private void checkConditions(Assertion assertion, SamlAuthenticationOption samlServer) throws BadRequestException {
        Conditions conditions = assertion.getConditions();
        if (conditions == null) {
            return;
        }
        Instant now = getNow();
        Instant notBefore = conditions.getNotBefore();
        if (notBefore != null && now.isBefore(notBefore)) {
            throw new BadRequestException("Returned assertion is too new! May not be used before "
                    + String.valueOf(notBefore) + " but now is " + String.valueOf(now));
        }
        Instant notOnOrAfter = conditions.getNotOnOrAfter();
        // NotOnOrAfter is exclusive: the assertion has already expired at exactly that instant.
        if (notOnOrAfter != null && !now.isBefore(notOnOrAfter)) {
            throw new BadRequestException("Returned assertion is too old! May not be used after "
                    + String.valueOf(notOnOrAfter) + " but now is " + String.valueOf(now));
        }
        for (Condition condition : conditions.getConditions()) {
            if (!(condition instanceof AudienceRestriction audienceRestriction)) {
                continue;
            }
            List<String> audiences = audienceRestriction.getAudiences().stream().map(XSURI::getURI).toList();
            if (audiences.stream().noneMatch(uri -> samlServer.serviceProviderId.equals(uri))) {
                throw new BadRequestException("The assertion was created for a different audience! The Service "
                        + "Provider/Client ID configured in Teamscale does not match any of the following audiences: "
                        + StringUtils.concat(audiences, ", "));
            }
        }
    }

    /**
     * Validates the signature's profile and cryptographic validity against the IdP's credential.
     *
     * @throws BadRequestException if there is no signature or it does not validate
     */
    private void checkSignature(Signature signature, SamlAuthenticationOption samlServer) throws BadRequestException {
        if (signature == null) {
            throw new BadRequestException("Response from IdP has no global signature and no signature in assertion!");
        }
        try {
            new SAMLSignatureProfileValidator().validate(signature);
            SignatureValidator.validate(signature, getCredential(samlServer.metadataXml));
        } catch (SignatureException e) {
            throw new BadRequestException(appendCauseMessages("Assertion signature could not be validated!", e), e);
        }
    }

    /**
     * Returns the value of the subject's name ID, decrypting an encrypted ID if present.
     *
     * @throws BadRequestException if the subject or its name ID is missing/empty, or decryption fails or does
     *                             not yield a name ID
     */
    @VisibleForTesting
    static String extractSubjectNameId(Assertion assertion, SamlAuthenticationOption option)
            throws BadRequestException {
        Subject subject = assertion.getSubject();
        if (subject == null) {
            throw new BadRequestException("IdP returned assertion without subject!");
        }
        EncryptedID encryptedNameId = subject.getEncryptedID();
        if (encryptedNameId == null) {
            return extractValueOfNameId(subject.getNameID());
        }
        Object decryptedId;
        try {
            decryptedId = SamlUtils.getDecrypter(option).decrypt(encryptedNameId);
        } catch (DecryptionException e) {
            throw new BadRequestException(appendCauseMessages("Failed to decrypt encrypted name ID.", e), e);
        }
        // An EncryptedID may wrap other identifier types; report that as a bad request instead of a ClassCastException.
        if (!(decryptedId instanceof NameID nameId)) {
            throw new BadRequestException("IdP returned an encrypted subject identifier that is not a name ID!");
        }
        return extractValueOfNameId(nameId);
    }

    /**
     * Collects the non-empty string values of all (decrypted and plain) attributes of all attribute statements,
     * grouped by attribute name. Values of attributes with the same name are concatenated.
     *
     * @throws BadRequestException if an encrypted attribute cannot be decrypted
     */
    @VisibleForTesting
    static Map<String, List<String>> extractAttributes(Assertion assertion, SamlAuthenticationOption option)
            throws BadRequestException {
        Map<String, List<String>> attributes = new HashMap<>();
        for (AttributeStatement statement : assertion.getAttributeStatements()) {
            for (EncryptedAttribute encryptedAttribute : statement.getEncryptedAttributes()) {
                try {
                    extractAttributes(attributes, SamlUtils.getDecrypter(option).decrypt(encryptedAttribute));
                } catch (DecryptionException e) {
                    throw new BadRequestException(appendCauseMessages("Failed to decrypt encrypted attribute.", e), e);
                }
            }
            for (Attribute attribute : statement.getAttributes()) {
                extractAttributes(attributes, attribute);
            }
        }
        return attributes;
    }

    /** Adds the non-empty string values of the attribute to the list stored under its name. */
    private static void extractAttributes(Map<String, List<String>> attributes, Attribute attribute) {
        attributes.computeIfAbsent(attribute.getName(), name -> new ArrayList<>())
                .addAll(getAllNonNullStringValues(attribute.getAttributeValues()));
    }

    /**
     * Returns the value of the name ID.
     *
     * @throws BadRequestException if the name ID is missing or its value is empty
     */
    private static String extractValueOfNameId(NameID nameId) {
        if (nameId == null) {
            throw new BadRequestException("IdP returned assertion subject without name ID!");
        }
        String value = nameId.getValue();
        if (StringUtils.isEmpty(value)) {
            throw new BadRequestException("IdP returned assertion subject with empty name ID!");
        }
        return value;
    }

    /**
     * Returns the non-empty node values of the first DOM child of each attribute value. Values whose DOM is not
     * available cause a {@link NullPointerException}.
     */
    private static List<String> getAllNonNullStringValues(List<XMLObject> values) {
        return values.stream()
                .filter(Objects::nonNull)
                .map(XMLObject::getDOM)
                .map(Objects::requireNonNull)
                .map(Node::getFirstChild)
                .filter(Objects::nonNull)
                .map(Node::getNodeValue)
                .filter(Predicate.not(StringUtils::isEmpty))
                .toList();
    }

    /**
     * Joins {@code message} with the messages of {@code e} and all of its causes (down to the root cause), one
     * per line. Null and duplicate messages are skipped.
     */
    @VisibleForTesting
    static String appendCauseMessages(String message, Exception e) {
        Stream<String> chainMessages = Stream.<Throwable>iterate(e, Objects::nonNull, Throwable::getCause)
                .map(Throwable::getMessage)
                .filter(Objects::nonNull);
        return Stream.concat(Stream.of(message), chainMessages).distinct().collect(Collectors.joining("\n"));
    }

    /** Clears all metadata-derived caches, each under the lock that guards it. */
    synchronized void clearCaches() {
        synchronized (ssoLocationCache) {
            ssoLocationCache.clear();
        }
        credentialCache.clear();
        entityIdCache.clear();
    }

    /** Exposes the protected deflate-and-Base64 step of the HTTP-Redirect encoder for building request URLs. */
    private static class RequestEncodingHelper extends HTTPRedirectDeflateEncoder {

        private RequestEncodingHelper() {
        }

        @Override
        public @NonNull String deflateAndBase64Encode(@NonNull SAMLObject message) throws MessageEncodingException {
            return super.deflateAndBase64Encode(message);
        }
    }
}

