PKCS7 Padding

P
using System;

namespace Algorithms.Crypto.Paddings;

/// <summary>
/// <para>
/// This class implements the PKCS7 padding scheme, which is a standard way of padding data to fit a certain block size.
/// </para>
/// <para>
/// PKCS7 padding adds N bytes of value N to the end of the data, where N is the number of bytes needed to reach the block size.
/// For example, if the block size is 16 bytes, and the data is 11 bytes long, then 5 bytes of value 5 will be added to the
/// end of the data. This way, the padded data will be 16 bytes long and can be encrypted or decrypted by a block cipher algorithm.
/// If the data length is already a multiple of the block size, a full block of padding (N = block size) is added, so that
/// padding is always present and can be removed unambiguously.
/// </para>
/// <para>
/// The padding can be removed after decryption by reading the last byte N, checking that the last N bytes all equal N,
/// and then discarding those N bytes from the end of the data.
/// </para>
/// <para>
/// This class supports any block size from 1 to 255 bytes, and can be used with any encryption algorithm that requires
/// padding, such as AES.
/// </para>
/// </summary>
public class Pkcs7Padding : IBlockCipherPadding
{
    private readonly int blockSize;

    /// <summary>
    /// Initializes a new PKCS7 padding scheme for the given block size.
    /// </summary>
    /// <param name="blockSize">The cipher block size in bytes; must be between 1 and 255 inclusive.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// If <paramref name="blockSize"/> is less than 1 or greater than 255.
    /// </exception>
    public Pkcs7Padding(int blockSize)
    {
        // Each padding byte stores N (1..blockSize), so the block size must fit in a single byte.
        if (blockSize is < 1 or > 255)
        {
            throw new ArgumentOutOfRangeException(nameof(blockSize), $"Invalid block size: {blockSize}");
        }

        this.blockSize = blockSize;
    }

    /// <summary>
    /// Adds padding to a buffer according to the PKCS#7 standard, starting immediately after the data.
    /// </summary>
    /// <param name="input">
    /// The buffer that holds the data. It must have room after <paramref name="inputOffset"/> for the padding bytes.
    /// </param>
    /// <param name="inputOffset">
    /// The length of the data, i.e. the index at which the first padding byte is written.
    /// </param>
    /// <returns>The number of padding bytes added, which is also the value written to each of them.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="input"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">If <paramref name="inputOffset"/> is negative.</exception>
    /// <exception cref="ArgumentException">
    /// If the input array does not have enough space after <paramref name="inputOffset"/> for the padding.
    /// </exception>
    /// <remarks>
    /// The padding value is equal to the number of bytes that are added to the array.
    /// For example, with a block size of 16 and an input offset of 10, 6 bytes with the value 6
    /// are written to indices 10..15. If the offset is already a multiple of the block size,
    /// a full block of padding is written.
    /// </remarks>
    public int AddPadding(byte[] input, int inputOffset)
    {
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input), "Input cannot be null");
        }

        if (inputOffset < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(inputOffset), $"Invalid input offset: {inputOffset}");
        }

        // The padding length depends on how much data there is (inputOffset), not on the buffer size.
        // inputOffset % blockSize is in 0..blockSize-1, so code is in 1..blockSize: an aligned input
        // gets a full block of padding. blockSize <= 255 guarantees the cast to byte is lossless.
        var code = (byte)(blockSize - (inputOffset % blockSize));

        // Written as a subtraction so a very large inputOffset cannot overflow the comparison.
        if (code > input.Length - inputOffset)
        {
            throw new ArgumentException("Not enough space in input array for padding");
        }

        // Write N bytes of value N directly after the data.
        for (var i = 0; i < code; i++)
        {
            input[inputOffset + i] = code;
        }

        return code;
    }

    /// <summary>
    /// Removes the PKCS7 padding from the given input data.
    /// </summary>
    /// <param name="input">
    /// The input data with PKCS7 padding. Must not be null or empty, its length must be a multiple of the block size,
    /// and it must end with valid padding.
    /// </param>
    /// <returns>The input data without the padding as a new byte array.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="input"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown if the input data is empty, has an invalid length, or has an invalid padding.
    /// </exception>
    /// <remarks>
    /// NOTE(review): the byte comparison below stops at the first mismatch, so its timing depends on the data.
    /// If this is used on attacker-supplied ciphertext, prefer a constant-time check such as
    /// <see cref="GetPaddingCount"/> to avoid a padding oracle.
    /// </remarks>
    public byte[] RemovePadding(byte[] input)
    {
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input), "Input cannot be null");
        }

        // Valid padded data always contains at least one padding byte, so empty input is rejected up front
        // (it would otherwise pass the multiple-of-block-size check and fail on input[^1]).
        if (input.Length == 0)
        {
            throw new ArgumentException("Input cannot be empty");
        }

        // Check if input length is a multiple of blockSize
        if (input.Length % blockSize != 0)
        {
            throw new ArgumentException("Input length must be a multiple of block size");
        }

        // The last byte of the input holds the padding length N.
        var paddingLength = input[^1];

        // N must be in 1..blockSize.
        if (paddingLength < 1 || paddingLength > blockSize)
        {
            throw new ArgumentException("Invalid padding length");
        }

        // Every one of the last N bytes must equal N.
        for (var i = 0; i < paddingLength; i++)
        {
            if (input[input.Length - 1 - i] != paddingLength)
            {
                throw new ArgumentException("Invalid padding");
            }
        }

        // Copy everything except the trailing N padding bytes into a new array.
        var output = new byte[input.Length - paddingLength];
        Array.Copy(input, output, output.Length);

        return output;
    }

    /// <summary>
    /// Gets the number of padding bytes in the given input data according to the PKCS7 padding scheme.
    /// </summary>
    /// <param name="input">The input data with PKCS7 padding. Must not be null or empty and must have a valid padding.</param>
    /// <returns>The number of padding bytes in the input data (1..blockSize).</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="input"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown if the input data is empty or has an invalid padding (zero, longer than the input,
    /// longer than the block size, or with a padding byte that does not match).
    /// </exception>
    /// <remarks>
    /// The validity checks use bitwise operations instead of data-dependent branches, and every byte of the
    /// input is examined. The intent is that the running time does not reveal where the padding check failed.
    /// </remarks>
    public int GetPaddingCount(byte[] input)
    {
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input), "Input cannot be null");
        }

        if (input.Length == 0)
        {
            throw new ArgumentException("Input cannot be empty");
        }

        // The last byte holds the padding length N (a byte is already 0..255, so no masking is needed).
        var lastByte = input[^1];
        int paddingCount = lastByte;

        // Index of the first padding byte; negative when N is larger than the input.
        var paddingStartIndex = input.Length - paddingCount;

        // Branch-free range check. Each term is negative (most significant bit set) exactly when
        // one validity condition is violated:
        //   paddingStartIndex        < 0  -> N is larger than the input
        //   paddingCount - 1         < 0  -> N is zero
        //   blockSize - paddingCount < 0  -> N is larger than the block size
        // ORing the terms keeps the sign bit if any of them is negative, and an arithmetic shift right
        // by 31 turns that into -1 (failure) or 0 (success).
        var paddingCheckFailed = (paddingStartIndex | (paddingCount - 1) | (blockSize - paddingCount)) >> 31;

        for (var i = 0; i < input.Length; i++)
        {
            // (i - paddingStartIndex) >> 31 is -1 for bytes before the padding and 0 for padding bytes.
            // Its complement is therefore a mask that is 0 for data bytes and all ones for padding bytes.
            // XOR with lastByte is non-zero for any padding byte that differs from N. ORing the masked result
            // into the accumulator records any mismatch without branching on the data.
            paddingCheckFailed |= (input[i] ^ lastByte) & ~((i - paddingStartIndex) >> 31);
        }

        if (paddingCheckFailed != 0)
        {
            throw new ArgumentException("Padding block is corrupted");
        }

        return paddingCount;
    }
}