#include <limits.h>
#include <string.h>
#include <openssl/evp.h>

#include "crypto.h"

using namespace std;

/*
 * AES-128-GCM encryption through the OpenSSL EVP interface.
 *
 * Encrypts src_len bytes from p_src into p_dst, authenticates the optional
 * additional data (p_aad, aad_len), and writes an SGX_AESGCM_MAC_SIZE-byte
 * authentication tag to p_out_mac. GCM is a stream mode, so the ciphertext
 * written to p_dst is exactly src_len bytes.
 *
 * Argument rules (violations return -1 before any work is done):
 *   - p_key, p_iv and p_out_mac must be non-NULL;
 *   - iv_len must equal SGX_AESGCM_IV_SIZE;
 *   - src_len and aad_len must be below INT_MAX (OpenSSL takes int lengths);
 *   - p_src and p_dst must be non-NULL when src_len > 0;
 *   - p_aad must be non-NULL when aad_len > 0;
 *   - at least one of p_src / p_aad must be non-NULL.
 *
 * Returns 0 on success, -1 on any failure.
 */
int aes_gcm_encrypt(const sgx_aes_gcm_128bit_key_t *p_key, const uint8_t *p_src, uint32_t src_len,
                    uint8_t *p_dst, const uint8_t *p_iv, uint32_t iv_len, const uint8_t *p_aad, uint32_t aad_len,
                    sgx_aes_gcm_128bit_tag_t *p_out_mac) {
	if ((src_len >= INT_MAX) || (aad_len >= INT_MAX) || (p_key == NULL) || ((src_len > 0) && (p_dst == NULL)) || ((src_len > 0) && (p_src == NULL))
		|| (p_out_mac == NULL) || (iv_len != SGX_AESGCM_IV_SIZE) || ((aad_len > 0) && (p_aad == NULL))
		|| (p_iv == NULL) || ((p_src == NULL) && (p_aad == NULL)))
	{
		return -1;
	}
	int ret = -1;
	// Keep the per-step output lengths apart: the AAD update reports aad_len,
	// which is NOT an offset into p_dst. Reusing a single variable made the
	// final call compute p_dst + aad_len, undefined when p_dst is NULL.
	int aad_out_len = 0;
	int out_len = 0;    // ciphertext bytes written to p_dst; stays 0 without payload
	int final_len = 0;
	EVP_CIPHER_CTX *pState = NULL;

	do {
		// Create the cipher context
		//
		if (!(pState = EVP_CIPHER_CTX_new())) {
			break;
		}

		// Select AES-128-GCM and load key and IV (default GCM IV length is used)
		//
		if (1 != EVP_EncryptInit_ex(pState, EVP_aes_128_gcm(), NULL,
		                            reinterpret_cast<const unsigned char *>(p_key), p_iv)) {
			break;
		}

		// Feed the additional authenticated data, if any (NULL output = AAD)
		//
		if (NULL != p_aad) {
			if (1 != EVP_EncryptUpdate(pState, NULL, &aad_out_len, p_aad, static_cast<int>(aad_len))) {
				break;
			}
		}

		// Encrypt the payload. Skipped when empty: p_src may be NULL then, and
		// a NULL-input update can be interpreted by OpenSSL as finalization.
		//
		if (src_len > 0) {
			if (1 != EVP_EncryptUpdate(pState, p_dst, &out_len, p_src, static_cast<int>(src_len))) {
				break;
			}
		}

		// Finalise the encryption. GCM emits no bytes here; when src_len == 0,
		// out_len is 0, so p_dst + 0 is well-defined even if p_dst is NULL.
		//
		if (1 != EVP_EncryptFinal_ex(pState, p_dst + out_len, &final_len)) {
			break;
		}

		// Retrieve the authentication tag
		//
		if (1 != EVP_CIPHER_CTX_ctrl(pState, EVP_CTRL_GCM_GET_TAG, SGX_AESGCM_MAC_SIZE, p_out_mac)) {
			break;
		}
		ret = 0;
	} while (0);

	// Clean up and return
	//
	if (pState != NULL) {
		EVP_CIPHER_CTX_free(pState);
	}
	return ret;
}

/*
 * AES-128-GCM decryption and tag verification through the OpenSSL EVP interface.
 *
 * Decrypts src_len bytes from p_src into p_dst, authenticates the optional
 * additional data (p_aad, aad_len), and verifies the SGX_AESGCM_MAC_SIZE-byte
 * expected tag at p_in_mac.
 *
 * Argument rules (violations return -1 before any work is done):
 *   - p_key, p_iv and p_in_mac must be non-NULL;
 *   - iv_len must equal SGX_AESGCM_IV_SIZE;
 *   - src_len and aad_len must be below INT_MAX (OpenSSL takes int lengths);
 *   - p_src and p_dst must be non-NULL when src_len > 0;
 *   - p_aad must be non-NULL when aad_len > 0;
 *   - at least one of p_src / p_aad must be non-NULL.
 *
 * Returns 0 only if decryption succeeded AND the tag verified; -1 otherwise.
 * If a failure occurs after plaintext started being written to p_dst, the
 * src_len bytes at p_dst are zeroed so unauthenticated plaintext is not left
 * behind.
 */
int aes_gcm_decrypt(const sgx_aes_gcm_128bit_key_t *p_key, const uint8_t *p_src, uint32_t src_len,
                    uint8_t *p_dst, const uint8_t *p_iv, uint32_t iv_len, const uint8_t *p_aad, uint32_t aad_len,
                    const sgx_aes_gcm_128bit_tag_t *p_in_mac) {
	uint8_t l_tag[SGX_AESGCM_MAC_SIZE];

	if ((src_len >= INT_MAX) || (aad_len >= INT_MAX) || (p_key == NULL) || ((src_len > 0) && (p_dst == NULL)) || ((src_len > 0) && (p_src == NULL))
		|| (p_in_mac == NULL) || (iv_len != SGX_AESGCM_IV_SIZE) || ((aad_len > 0) && (p_aad == NULL))
		|| (p_iv == NULL) || ((p_src == NULL) && (p_aad == NULL))) {
		return -1;
	}
	int ret = -1;
	// Separate per-step lengths: the AAD update reports aad_len, which must
	// never be used as an offset into p_dst (p_dst may be NULL without payload).
	int aad_out_len = 0;
	int out_len = 0;    // plaintext bytes written to p_dst; stays 0 without payload
	int final_len = 0;
	bool dst_written = false;  // set once p_dst may hold (unverified) plaintext
	EVP_CIPHER_CTX *pState = NULL;

	// Writable copy of the expected tag: EVP_CIPHER_CTX_ctrl takes a non-const
	// pointer while p_in_mac is const. memcpy fills every byte of l_tag.
	//
	memcpy(l_tag, p_in_mac, SGX_AESGCM_MAC_SIZE);

	do {
		// Create the cipher context
		//
		if (!(pState = EVP_CIPHER_CTX_new())) {
			break;
		}

		// Select AES-128-GCM and load key and IV (default GCM IV length is used)
		//
		if (1 != EVP_DecryptInit_ex(pState, EVP_aes_128_gcm(), NULL,
		                            reinterpret_cast<const unsigned char *>(p_key), p_iv)) {
			break;
		}

		// Feed the additional authenticated data, if any (NULL output = AAD)
		//
		if (NULL != p_aad) {
			if (1 != EVP_DecryptUpdate(pState, NULL, &aad_out_len, p_aad, static_cast<int>(aad_len))) {
				break;
			}
		}

		// Decrypt the payload. Skipped when empty, mirroring aes_gcm_encrypt:
		// p_src may be NULL then, and a NULL-input update can be interpreted by
		// OpenSSL as finalization (before the tag is set), failing the call.
		//
		if (src_len > 0) {
			dst_written = true;  // the update may write p_dst even if it fails
			if (1 != EVP_DecryptUpdate(pState, p_dst, &out_len, p_src, static_cast<int>(src_len))) {
				break;
			}
		}

		// Provide the expected tag for verification in the final step
		//
		if (1 != EVP_CIPHER_CTX_ctrl(pState, EVP_CTRL_GCM_SET_TAG, SGX_AESGCM_MAC_SIZE, l_tag)) {
			break;
		}

		// Finalise the decryption. A positive return value indicates success,
		// anything else is a failure - the plaintext is not trustworthy.
		// out_len is 0 when src_len == 0, so p_dst + 0 is well-defined even if
		// p_dst is NULL.
		//
		if (EVP_DecryptFinal_ex(pState, p_dst + out_len, &final_len) <= 0) {
			break;
		}
		ret = 0;
	} while (0);

	// Clean up and return
	//
	if (pState != NULL) {
		EVP_CIPHER_CTX_free(pState);
	}
	// Do not leave plaintext that failed authentication in the caller's buffer.
	if ((ret != 0) && dst_written) {
		memset(p_dst, 0, src_len);
	}
	// Wipe the local tag copy (the original memset had value/size swapped and
	// cleared nothing). The tag is not secret, so a possible dead-store
	// elimination of this call is harmless.
	memset(l_tag, 0, sizeof(l_tag));
	return ret;
}