/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "PublicKeyPinningService.h"

#include "RootCertificateTelemetryUtils.h"
#include "mozilla/ArrayUtils.h"
#include "mozilla/Base64.h"
#include "mozilla/BinarySearch.h"
#include "mozilla/Casting.h"
#include "mozilla/Logging.h"
#include "mozilla/Span.h"
#include "mozilla/StaticPrefs_security.h"
#include "mozilla/Telemetry.h"
#include "nsDependentString.h"
#include "nsServiceManagerUtils.h"
#include "nsSiteSecurityService.h"
#include "mozpkix/pkixtypes.h"
#include "mozpkix/pkixutil.h"
#include "seccomon.h"
#include "sechash.h"

#include "StaticHPKPins.h"  // autogenerated by genHPKPStaticpins.js

using namespace mozilla;
using namespace mozilla::pkix;
using namespace mozilla::psm;

// Log module used by every MOZ_LOG call in this file (all at LogLevel::Debug,
// with messages prefixed "pkpin:").
LazyLogModule gPublicKeyPinningLog("PublicKeyPinningService");

// Implements nsISupports (AddRef/Release/QueryInterface) for the service,
// exposing it through nsIPublicKeyPinningService.
NS_IMPL_ISUPPORTS(PublicKeyPinningService, nsIPublicKeyPinningService)

// Pinning enforcement levels. The numeric values are those of the
// security.cert_pinning.enforcement_level pref, which GetPinningMode() casts
// directly onto this enum.
enum class PinningMode : uint32_t {
  // No pin checks: every chain is accepted and no host reports pins.
  Disabled = 0,
  // Pins are only enforced for chains ending in a built-in root; chains
  // rooted in user-installed CAs are accepted without a pin check.
  AllowUserCAMITM = 1,
  // Pins are enforced regardless of the root (test-mode entries excepted).
  Strict = 2,
  // Like Strict, but entries flagged as test-mode are enforced as well.
  EnforceTestMode = 3
};

// Reads the pinning enforcement-level pref and maps it onto PinningMode.
// Any value beyond the last known level is treated as PinningMode::Disabled.
PinningMode GetPinningMode() {
  const uint32_t level =
      StaticPrefs::security_cert_pinning_enforcement_level_DoNotUseDirectly();
  // The known levels form the contiguous range 0..EnforceTestMode, so a
  // single upper-bound check suffices.
  if (level > static_cast<uint32_t>(PinningMode::EnforceTestMode)) {
    return PinningMode::Disabled;
  }
  return static_cast<PinningMode>(level);
}

/**
 * Stores in |hashSPKIDigest| the base64 encoding of the SHA-256 digest of
 * |cert|'s DER-encoded SubjectPublicKeyInfo. |hashSPKIDigest| is emptied
 * first, so it is left empty if computing the digest fails.
 */
static nsresult GetBase64HashSPKI(const BackCert& cert,
                                  nsACString& hashSPKIDigest) {
  hashSPKIDigest.Truncate();

  const Input spki = cert.GetSubjectPublicKeyInfo();
  nsTArray<uint8_t> sha256;
  const nsresult digestRv = Digest::DigestBuf(
      SEC_OID_SHA256, spki.UnsafeGetData(), spki.GetLength(), sha256);
  if (NS_FAILED(digestRv)) {
    return digestRv;
  }

  // View the raw digest bytes as a char string so they can be base64-encoded.
  const nsDependentCSubstring rawDigest(
      BitwiseCast<char*, uint8_t*>(sha256.Elements()), sha256.Length());
  return Base64Encode(rawDigest, hashSPKIDigest);
}

/*
 * Sets certMatchesPinset to true if the base64-encoded SHA-256 hash of the
 * given cert's SubjectPublicKeyInfo equals any fingerprint in the given
 * pinset, and to false otherwise.
 *
 * Returns NS_ERROR_INVALID_ARG if |fingerprints| is null, or the failure code
 * from GetBase64HashSPKI; certMatchesPinset is false in both cases.
 */
static nsresult EvalCert(const BackCert& cert,
                         const StaticFingerprints* fingerprints,
                         /*out*/ bool& certMatchesPinset) {
  certMatchesPinset = false;
  if (!fingerprints) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
            ("pkpin: No hashes found\n"));
    return NS_ERROR_INVALID_ARG;
  }

  nsAutoCString base64Out;
  nsresult rv = GetBase64HashSPKI(cert, base64Out);
  if (NS_FAILED(rv)) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
            ("pkpin: GetBase64HashSPKI failed!\n"));
    return rv;
  }

  // |fingerprints| was null-checked above, so the pinset can be walked
  // directly; the first matching fingerprint ends the scan.
  for (size_t i = 0; i < fingerprints->size; i++) {
    if (base64Out.Equals(fingerprints->data[i])) {
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
              ("pkpin: found pin base_64 ='%s'\n", base64Out.get()));
      certMatchesPinset = true;
      return NS_OK;
    }
  }
  return NS_OK;
}

/*
 * Sets certListIntersectsPinset to true if the SPKI hash of any certificate in
 * derCertList appears in the given static fingerprints, and to false
 * otherwise. The first certificate is parsed as the end-entity and every later
 * one as a CA. Certificates are checked in list order and the scan stops at the
 * first match. Returns NS_ERROR_INVALID_ARG if a certificate fails to parse.
 */
static nsresult EvalChain(const nsTArray<Span<const uint8_t>>& derCertList,
                          const StaticFingerprints* fingerprints,
                          /*out*/ bool& certListIntersectsPinset) {
  certListIntersectsPinset = false;
  if (!fingerprints) {
    MOZ_ASSERT(false, "Must pass in at least one type of pinset");
    return NS_ERROR_FAILURE;
  }

  for (size_t index = 0; index < derCertList.Length(); ++index) {
    const Span<const uint8_t>& der = derCertList[index];
    const EndEntityOrCA role =
        index == 0 ? EndEntityOrCA::MustBeEndEntity : EndEntityOrCA::MustBeCA;

    Input certInput;
    if (certInput.Init(der.data(), der.size()) !=
        mozilla::pkix::Result::Success) {
      return NS_ERROR_INVALID_ARG;
    }
    BackCert backCert(certInput, role, nullptr);
    if (backCert.Init() != mozilla::pkix::Result::Success) {
      return NS_ERROR_INVALID_ARG;
    }

    nsresult nsrv = EvalCert(backCert, fingerprints, certListIntersectsPinset);
    if (NS_FAILED(nsrv)) {
      return nsrv;
    }
    if (certListIntersectsPinset) {
      // One pinned key anywhere in the chain is enough.
      return NS_OK;
    }
  }

  // Reaching here means no certificate matched.
  MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
          ("pkpin: no matches found\n"));
  return NS_OK;
}

class TransportSecurityPreloadBinarySearchComparator {
 public:
  explicit TransportSecurityPreloadBinarySearchComparator(
      const char* aTargetHost)
      : mTargetHost(aTargetHost) {}

  int operator()(const TransportSecurityPreload& val) const {
    return strcmp(mTargetHost, val.mHost);
  }

 private:
  const char* mTargetHost;  // non-owning
};

#ifdef DEBUG
// Set once the preload list has been walked, so the check runs at most once
// per process in the common case.
static Atomic<bool> sValidatedPinningPreloadList(false);

// Debug-only consistency check of the generated preload list: an entry has a
// telemetry ID (mId != kUnknownId) exactly when it is a Mozilla entry.
static void ValidatePinningPreloadList() {
  if (!sValidatedPinningPreloadList) {
    for (const auto& entry : kPublicKeyPinningPreloadList) {
      MOZ_ASSERT((entry.mIsMoz && entry.mId != kUnknownId) ||
                 (!entry.mIsMoz && entry.mId == kUnknownId));
    }
    sValidatedPinningPreloadList = true;
  }
}
#endif  // DEBUG

// Looks up the static preload entry whose pins apply to |hostname| at |time|.
//
// The sorted preload list is binary-searched for the full hostname first and
// then for each parent domain (dropping one leading label at a time). An entry
// for the exact hostname always applies. An entry for a parent domain applies
// only if its mIncludeSubdomains flag is set. The walk stops at the first
// applicable entry.
//
// On return, staticFingerprints points to that entry, or is nullptr in these
// cases:
//   - no entry applies;
//   - the entry has no pinset;
//   - |time| is past the preload list's expiration
//     (kPreloadPKPinsExpirationTime);
//   - the arguments are rejected.
// Returns NS_ERROR_INVALID_ARG for a null or empty hostname, NS_OK otherwise.
static nsresult FindPinningInformation(
    const char* hostname, mozilla::pkix::Time time,
    /*out*/ const TransportSecurityPreload*& staticFingerprints) {
#ifdef DEBUG
  ValidatePinningPreloadList();
#endif
  // Clear the out-parameter before any early return so callers never observe
  // a stale entry, even when the arguments are rejected.
  staticFingerprints = nullptr;
  if (!hostname || hostname[0] == 0) {
    return NS_ERROR_INVALID_ARG;
  }
  const TransportSecurityPreload* foundEntry = nullptr;
  const char* evalHost = hostname;
  const char* evalPart;
  // Only names that still contain a '.' are searched. The final label on its
  // own, and hence any unqualified hostname, is never looked up.
  while (!foundEntry && (evalPart = strchr(evalHost, '.'))) {
    MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
            ("pkpin: Querying pinsets for host: '%s'\n", evalHost));
    size_t foundEntryIndex;
    if (BinarySearchIf(kPublicKeyPinningPreloadList, 0,
                       std::size(kPublicKeyPinningPreloadList),
                       TransportSecurityPreloadBinarySearchComparator(evalHost),
                       &foundEntryIndex)) {
      foundEntry = &kPublicKeyPinningPreloadList[foundEntryIndex];
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
              ("pkpin: Found pinset for host: '%s'\n", evalHost));
      // An entry for a parent domain only applies if it covers subdomains;
      // otherwise keep walking up to the next parent.
      if (evalHost != hostname && !foundEntry->mIncludeSubdomains) {
        foundEntry = nullptr;
      }
    } else {
      MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
              ("pkpin: Didn't find pinset for host: '%s'\n", evalHost));
    }
    // Skip past the '.' to the parent domain.
    evalHost = evalPart + 1;
  }

  if (foundEntry && foundEntry->pinset) {
    // Preloaded pins no longer apply once the list has expired.
    if (time > TimeFromEpochInSeconds(kPreloadPKPinsExpirationTime /
                                      PR_USEC_PER_SEC)) {
      return NS_OK;
    }
    staticFingerprints = foundEntry;
  }
  return NS_OK;
}

// Returns true via the output parameter if the given certificate list meets
// pinning requirements for the given host at the given time. It must be the
// case that either there is an intersection between the set of hashes of
// subject public key info data in the list and the most relevant non-expired
// pinset for the host or there is no pinning information for the host.
//
// Entries flagged as test-mode are only enforced when |enforceTestMode| is
// true. Otherwise the chain is accepted whatever the match result, and the
// result is still reported through telemetry and logging.
//
// If |pinningTelemetryInfo| is non-null, it receives the result histogram and
// bucket. On a mismatch with a known root, it also receives the root CA
// bucket. Returns NS_ERROR_FAILURE if the entry's Mozilla flag and telemetry
// ID disagree (only Mozilla entries carry an ID).
static nsresult CheckPinsForHostname(
    const nsTArray<Span<const uint8_t>>& certList, const char* hostname,
    bool enforceTestMode, mozilla::pkix::Time time,
    /*out*/ bool& chainHasValidPins,
    /*optional out*/ PinningTelemetryInfo* pinningTelemetryInfo) {
  chainHasValidPins = false;
  if (certList.IsEmpty()) {
    return NS_ERROR_INVALID_ARG;
  }
  if (!hostname || hostname[0] == 0) {
    return NS_ERROR_INVALID_ARG;
  }

  const TransportSecurityPreload* staticFingerprints = nullptr;
  nsresult rv = FindPinningInformation(hostname, time, staticFingerprints);
  if (NS_FAILED(rv)) {
    return rv;
  }
  // If we have no pinning information, the certificate chain trivially
  // validates with respect to pinning.
  if (!staticFingerprints) {
    chainHasValidPins = true;
    return NS_OK;
  }

  // Whether some certificate in the chain matches the host's pinset. This is
  // computed independently of whether the entry is enforced.
  bool chainMatchesPinset;
  rv = EvalChain(certList, staticFingerprints->pinset, chainMatchesPinset);
  if (NS_FAILED(rv)) {
    return rv;
  }
  // Test-mode entries only report their result unless test mode is enforced.
  chainHasValidPins = chainMatchesPinset ||
                      (staticFingerprints->mTestMode && !enforceTestMode);

  if (pinningTelemetryInfo) {
    // If and only if a static entry is a Mozilla entry, it has a telemetry
    // ID.
    if ((staticFingerprints->mIsMoz &&
         staticFingerprints->mId == kUnknownId) ||
        (!staticFingerprints->mIsMoz &&
         staticFingerprints->mId != kUnknownId)) {
      return NS_ERROR_FAILURE;
    }

    Telemetry::HistogramID histogram;
    int32_t bucket;
    // We can collect per-host pinning violations for this host because it is
    // operationally critical to Firefox.
    if (staticFingerprints->mIsMoz) {
      histogram = staticFingerprints->mTestMode
                      ? Telemetry::CERT_PINNING_MOZ_TEST_RESULTS_BY_HOST
                      : Telemetry::CERT_PINNING_MOZ_RESULTS_BY_HOST;
      // Two buckets per host: even = mismatch, odd = match.
      bucket = staticFingerprints->mId * 2 + (chainMatchesPinset ? 1 : 0);
    } else {
      histogram = staticFingerprints->mTestMode
                      ? Telemetry::CERT_PINNING_TEST_RESULTS
                      : Telemetry::CERT_PINNING_RESULTS;
      bucket = chainMatchesPinset ? 1 : 0;
    }
    pinningTelemetryInfo->accumulateResult = true;
    pinningTelemetryInfo->certPinningResultHistogram = Some(histogram);
    pinningTelemetryInfo->certPinningResultBucket = bucket;

    // We only collect per-CA pinning statistics upon failures.
    if (!chainMatchesPinset) {
      int32_t binNumber = RootCABinNumber(certList.LastElement());
      if (binNumber != ROOT_CERTIFICATE_UNKNOWN) {
        pinningTelemetryInfo->accumulateForRoot = true;
        pinningTelemetryInfo->rootBucket = binNumber;
      }
    }
  }

  MOZ_LOG(gPublicKeyPinningLog, LogLevel::Debug,
          ("pkpin: Pin check %s for %s host '%s' (mode=%s)\n",
           chainMatchesPinset ? "passed" : "failed",
           staticFingerprints->mIsMoz ? "mozilla" : "non-mozilla", hostname,
           staticFingerprints->mTestMode ? "test" : "production"));

  return NS_OK;
}

// Checks |certList| against the static pins for |hostname| under the current
// pinning enforcement level. The chain is reported valid without a pin check
// in two cases: pinning is disabled, or the level only enforces pins for
// built-in roots and this chain's root is not built in. Otherwise the
// hostname is canonicalized with CanonicalizeHostname() before the check.
nsresult PublicKeyPinningService::ChainHasValidPins(
    const nsTArray<Span<const uint8_t>>& certList, const char* hostname,
    mozilla::pkix::Time time, bool isBuiltInRoot,
    /*out*/ bool& chainHasValidPins,
    /*optional out*/ PinningTelemetryInfo* pinningTelemetryInfo) {
  const PinningMode mode = GetPinningMode();
  const bool skipPinning =
      mode == PinningMode::Disabled ||
      (mode == PinningMode::AllowUserCAMITM && !isBuiltInRoot);
  if (skipPinning) {
    chainHasValidPins = true;
    return NS_OK;
  }

  chainHasValidPins = false;
  if (certList.IsEmpty() || !hostname || hostname[0] == 0) {
    return NS_ERROR_INVALID_ARG;
  }
  const nsAutoCString canonicalHost(CanonicalizeHostname(hostname));
  return CheckPinsForHostname(certList, canonicalHost.get(),
                              mode == PinningMode::EnforceTestMode, time,
                              chainHasValidPins, pinningTelemetryInfo);
}

// nsIPublicKeyPinningService::HostHasPins: sets *hostHasPins to whether the
// static preload list has enforceable, unexpired pins for aURI's host right
// now. It is always false when pinning is disabled or the host is an IP
// address. Test-mode entries count only under PinningMode::EnforceTestMode.
NS_IMETHODIMP
PublicKeyPinningService::HostHasPins(nsIURI* aURI, bool* hostHasPins) {
  NS_ENSURE_ARG(aURI);
  NS_ENSURE_ARG(hostHasPins);
  *hostHasPins = false;

  const PinningMode mode = GetPinningMode();
  if (mode == PinningMode::Disabled) {
    return NS_OK;
  }

  nsAutoCString host;
  nsresult rv = nsSiteSecurityService::GetHost(aURI, host);
  if (NS_FAILED(rv)) {
    return rv;
  }
  // IP addresses never report pins.
  if (nsSiteSecurityService::HostIsIPAddress(host)) {
    return NS_OK;
  }

  const TransportSecurityPreload* entry = nullptr;
  rv = FindPinningInformation(host.get(), Now(), entry);
  if (NS_FAILED(rv)) {
    return rv;
  }
  *hostHasPins = entry != nullptr &&
                 (!entry->mTestMode || mode == PinningMode::EnforceTestMode);
  return NS_OK;
}

// Returns |hostname| lowercased with every trailing '.' removed. For example,
// "Example.COM." and "example.com" canonicalize to the same string.
nsAutoCString PublicKeyPinningService::CanonicalizeHostname(
    const char* hostname) {
  nsAutoCString canonical(hostname);
  ToLowerCase(canonical);
  // Count how many characters survive once trailing dots are dropped, then
  // cut the string once.
  auto keep = canonical.Length();
  while (keep > 0 && canonical.CharAt(keep - 1) == '.') {
    --keep;
  }
  canonical.Truncate(keep);
  return canonical;
}
