#include "xc_functional.h"

#include <algorithm>
#include <cctype>

#include "source_pw/module_pwdft/global.h"
#include "source_io/module_parameter/parameter.h"
#include "source_base/global_function.h"

#ifdef USE_LIBXC
#include "xc_functional_libxc.h"
#endif

// Default constructor; no per-instance initialisation is performed here.
XC_Functional::XC_Functional() = default;

// Destructor; there is nothing to release here.
XC_Functional::~XC_Functional() = default;

std::vector<int> XC_Functional::func_id(1);
int XC_Functional::func_type = 0;
bool XC_Functional::ked_flag = false;
bool XC_Functional::use_libxc = true;
double XC_Functional::hybrid_alpha = 0.25;
std::map<int, double> XC_Functional::scaling_factor_xc = { {1, 1.0} }; // added by jghan, 2024-10-10

/**
 * Overwrite the class-wide hybrid mixing parameter.
 * @param alpha_in new value for XC_Functional::hybrid_alpha (initially 0.25).
 */
void XC_Functional::set_hybrid_alpha(const double alpha_in)
{
    XC_Functional::hybrid_alpha = alpha_in;
}

void XC_Functional::set_xc_first_loop(const UnitCell& ucell)
{
    /** In the special "two-level" calculation case, the first scf iteration
    only calculates the functional without exact exchange. In "nscf"
    calculations there is no need for the "two-level" method.

    Reads the functional name of the first atom type and, for hybrid-like
    functionals, switches the global XC setting to its semilocal
    counterpart via set_xc_type(); other names leave the setting untouched. */

    // Normalize to upper case so the matching agrees with set_xc_type(),
    // which upper-cases its argument before dispatching.
    // Cast through unsigned char: std::toupper on a negative char is UB.
    std::string xc_func = ucell.atoms[0].ncpp.xc_func;
    std::transform(xc_func.begin(), xc_func.end(), xc_func.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });

    const bool first_loop_with_pbe =
        // exact-exchange hybrids
        xc_func == "HF" || xc_func == "PBE0" || xc_func == "HSE"
        // long-range corrected / CAM hybrids (LRC_WPBE was previously missing)
        || xc_func == "LC_PBE" || xc_func == "LC_WPBE" || xc_func == "LRC_WPBE"
        || xc_func == "LRC_WPBEH" || xc_func == "CAM_PBEH"
        // added by jghan, 2024-07-07
        || xc_func == "MULLER" || xc_func == "POWER"
        || xc_func == "WP22" || xc_func == "CWP22";

    if (first_loop_with_pbe)
    {
        XC_Functional::set_xc_type("pbe");
    }
    else if (xc_func == "B3LYP")
    {
        XC_Functional::set_xc_type("blyp");
    }
    else if (xc_func == "SCAN0")
    {
        XC_Functional::set_xc_type("scan");
    }
}

// Functional ids are set according to their indices in LIBXC;
// for details, refer to https://www.tddft.org/programs/libxc/functionals/
void XC_Functional::set_xc_type(const std::string xc_func_in)
{
    //Note : due to the separation of gcx_spin and gcc_spin,
    //when you are adding new GGA functionals,
    //please put exchange first, followed by correlation,
    //such as for PBE we have:
    //        func_id.push_back(XC_GGA_X_PBE);
    //        func_id.push_back(XC_GGA_C_PBE);

    func_id.clear();
    scaling_factor_xc.clear(); // added by jghan, 2024-07-07
    std::string xc_func = xc_func_in;
    std::transform(xc_func.begin(), xc_func.end(), xc_func.begin(), (::toupper));
	if( xc_func == "LDA" || xc_func == "PZ" || xc_func == "SLAPZNOGXNOGC") //SLA+PZ
	{
        func_id.push_back(XC_LDA_X);
        func_id.push_back(XC_LDA_C_PZ);
        func_type = 1;
        use_libxc = false;
	}
    else if (xc_func == "PWLDA")
    {
        func_id.push_back(XC_LDA_X);
        func_id.push_back(XC_LDA_C_PW);
        func_type = 1;
        use_libxc = false;
    }
	else if ( xc_func == "PBE" || xc_func == "SLAPWPBXPBC") //PBX+PBC
	{
        func_id.push_back(XC_GGA_X_PBE);
        func_id.push_back(XC_GGA_C_PBE);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "PBESOL") //PBX_S+PBC_S
	{
        func_id.push_back(XC_GGA_X_PBE_SOL);
        func_id.push_back(XC_GGA_C_PBE_SOL);
        func_type = 2;
        use_libxc = false;
	}
	else if( xc_func == "REVPBE" ) //PBX_r+PBC
	{
		func_id.push_back(XC_GGA_X_PBE_R);
        func_id.push_back(XC_GGA_C_PBE);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "WC") //WC+PBC
	{
        func_id.push_back(XC_GGA_X_WC);
        func_id.push_back(XC_GGA_C_PBE);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "BLYP") //B88+LYP
	{
        func_id.push_back(XC_GGA_X_B88);
        func_id.push_back(XC_GGA_C_LYP);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "BP") //B88+P86
	{
        func_id.push_back(XC_GGA_X_B88);
        func_id.push_back(XC_GGA_C_P86);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "PW91") //PW91_X+PW91_C
	{
        func_id.push_back(XC_GGA_X_PW91);
        func_id.push_back(XC_GGA_C_PW91);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "HCTH") //HCTH_X+HCTH_C
	{
        func_id.push_back(XC_GGA_X_HCTH_A);
        func_id.push_back(XC_GGA_C_HCTH_A);
        func_type = 2;
        use_libxc = false;
	}
	else if ( xc_func == "OLYP") //OPTX+LYP
	{
        func_id.push_back(XC_GGA_X_OPTX);
        func_id.push_back(XC_GGA_C_LYP);
        func_type = 2;
        use_libxc = false;
	}
#ifdef USE_LIBXC
	else if ( xc_func == "SCAN")
	{
        func_id.push_back(XC_MGGA_X_SCAN);
        func_id.push_back(XC_MGGA_C_SCAN);
        func_type = 3;
        use_libxc = true;
	}
    else if ( xc_func == "SCAN0")
	{
        func_id.push_back(XC_MGGA_X_SCAN);
        func_id.push_back(XC_MGGA_C_SCAN);
        func_type = 5;
        use_libxc = true;
	}
    else if( xc_func == "LC_PBE")
    {
        func_id.push_back(XC_HYB_GGA_XC_LC_PBEOP);
        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "LC_WPBE")
    {
        func_id.push_back(XC_HYB_GGA_XC_LC_WPBE);
        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "LRC_WPBE")
    {
        func_id.push_back(XC_HYB_GGA_XC_LRC_WPBE);
        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "LRC_WPBEH")
    {
        func_id.push_back(XC_HYB_GGA_XC_LRC_WPBEH);
        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "CAM_PBEH")
    {
        func_id.push_back(XC_HYB_GGA_XC_CAM_PBEH);
        func_type = 4;
        use_libxc = true;
    }
#endif
    else if( xc_func == "HF")
    {
        func_type = 4;
        use_libxc = false;
    }
   	else if( xc_func == "PBE0")
	{
        func_id.push_back(XC_HYB_GGA_XC_PBEH);
        func_type = 4;
        use_libxc = false;
	}
    else if( xc_func == "OPT_ORB" ||  xc_func == "NONE" || xc_func == "NOX+NOC")
    {
        // not doing anything
    }
    else if( xc_func == "MULLER" || xc_func == "POWER" ) // added by jghan, 2024-07-06
    {
        func_type = 4;
        use_libxc = false;
    }
#ifdef USE_LIBXC
    else if( xc_func == "HSE")
    {
        func_id.push_back(XC_HYB_GGA_XC_HSE06);
        func_type = 4;
        use_libxc = true;
    }
    // added by jghan, 2024-07-06
    else if( xc_func == "WP22")
    {
        func_id.push_back(XC_GGA_X_ITYH);   // short-range of B88_X, id=529
        func_id.push_back(XC_GGA_C_LYPR);   // short-range of LYP_C, id=624
        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "CWP22")
    {   
        // BLYP_XC_lr = -BLYP_XC_sr + BLYP_XC, the realization of it is in v_xc_libxc() function, xc_functional_libxc_vxc.cpp
        func_id.push_back(XC_GGA_X_ITYH);   // short-range of B88_X, id=529
        func_id.push_back(XC_GGA_C_LYPR);   // short-range of LYP_C, id=624
        func_id.push_back(XC_GGA_X_B88);    // complete B88_X, id=106
        func_id.push_back(XC_GGA_C_LYP);    // complete LYP_C, id=131

        // the scaling factor of CWP22-functionals
        scaling_factor_xc[XC_GGA_X_ITYH] = -1.0;
        scaling_factor_xc[XC_GGA_C_LYPR] = -1.0;
        scaling_factor_xc[XC_GGA_X_B88] = 1.0;
        scaling_factor_xc[XC_GGA_X_B88] = 1.0;

        func_type = 4;
        use_libxc = true;
    }
    else if( xc_func == "BLYP_LR")
    {   
        // BLYP_XC_lr = -BLYP_XC_sr + BLYP_XC, the realization of it is in v_xc_libxc() function, xc_functional_libxc_vxc.cpp
        func_id.push_back(XC_GGA_X_ITYH);   // short-range of B88_X, id=529
        func_id.push_back(XC_GGA_C_LYPR);   // short-range of LYP_C, id=624
        func_id.push_back(XC_GGA_X_B88);    // complete B88_X, id=106
        func_id.push_back(XC_GGA_C_LYP);    // complete LYP_C, id=131

        // the scaling factor of BLYP_LR-functionals
        scaling_factor_xc[XC_GGA_X_ITYH] = -1.0;
        scaling_factor_xc[XC_GGA_C_LYPR] = -1.0;
        scaling_factor_xc[XC_GGA_X_B88] = 1.0;
        scaling_factor_xc[XC_GGA_X_B88] = 1.0;

        func_type = 2;
        use_libxc = true;
    }
    else if (xc_func == "B3LYP")
    {
        func_id.push_back(XC_HYB_GGA_XC_B3LYP);
        func_type = 4;
        use_libxc = true;
    }
#endif
    else
    {
#ifdef USE_LIBXC
        //see if it matches libxc functionals
        const std::pair<int,std::vector<int>> type_id = XC_Functional_Libxc::set_xc_type_libxc(xc_func);
        func_type = std::get<0>(type_id);
        func_id = std::get<1>(type_id);
        use_libxc = true;
#else
        std::string message = "Unrecognized exchange-correlation functional '"+ xc_func +"'.\n"
                              " Possible source: Pseudopotential file or dft_functional parameter.\n"
                              " Please explicitly set dft_functional in INPUT,\n"
                              " or verify the functional name is supported.";
        ModuleBase::WARNING_QUIT("xc_functional.cpp",message);
#endif
    }

    if (func_type == 3 || func_type == 5)
    {
        ked_flag = true;
    }
    else
    {
        ked_flag = false;
    }

    if (func_id[0] == XC_GGA_X_OPTX)
    {
        std::cerr << "\n OPTX untested please test,";
    }

    // if((func_type == 4 || func_type == 5) && PARAM.inp.basis_type == "pw")
    // {
    //     ModuleBase::WARNING_QUIT("set_xc_type","hybrid functional not realized for planewave yet");
    // }
    if((func_type == 3 || func_type == 5) && PARAM.inp.nspin==4)
    {
        ModuleBase::WARNING_QUIT("set_xc_type","meta-GGA has not been implemented for nspin = 4 yet");
    }

#ifndef __EXX
    if((func_type == 4 || func_type == 5) && PARAM.inp.basis_type == "lcao")
    {
        ModuleBase::WARNING_QUIT("set_xc_type","compile with libri to use hybrid functional in lcao basis");
    }
#endif

#ifndef USE_LIBXC
    if(xc_func == "SCAN" || xc_func == "HSE" || xc_func == "SCAN0" 
        || xc_func == "MULLER" || xc_func == "POWER" || xc_func == "WP22" || xc_func == "CWP22" ||
        xc_func == "LC_PBE" || xc_func == "LC_WPBE" || xc_func == "LRC_WPBE" ||
        xc_func == "LRC_PBEH" || xc_func == "CAM_PBEH")
    {
        ModuleBase::WARNING_QUIT("set_xc_type","to use SCAN, SCAN0, HSE, long-range corrected (LC_PBE, LC_WPBE...) or CAM_PBEH LIBXC is required");
    }
    use_libxc = false;
#endif

}

/**
 * Build a human-readable description of the currently selected functional.
 *
 * With LIBXC compiled in and use_libxc set, the text contains the libxc
 * version and reference, then the name and references of every functional
 * in func_id. Otherwise it is a single tab-separated " XC:" line: functional
 * names from xc_functional_get_name() when LIBXC is compiled in, raw numeric
 * ids when it is not.
 */
std::string XC_Functional::output_info()
{
#ifdef USE_LIBXC
    if (!use_libxc)
    {
        std::string names = " XC:\t";
        for (const int id : func_id)
        {
            names += std::string(xc_functional_get_name(id)) + "\t";
        }
        return names;
    }

    std::stringstream report;
    report << " Libxc v" << xc_version_string() << '\n';
    report << "\t" << xc_reference() << '\n';

    // temporarily initialise the functionals just to query their metadata
    std::vector<xc_func_type> funcs = XC_Functional_Libxc::init_func(func_id, XC_UNPOLARIZED);
    for (const xc_func_type& func : funcs)
    {
        const xc_func_info_type* info = xc_func_get_info(&func);
        report << " XC: " << xc_func_info_get_name(info) << '\n';
        for (int iref = 0; iref < XC_MAX_REFERENCES; ++iref)
        {
            const func_reference_type* ref = xc_func_info_get_references(func.info, iref);
            if (ref != nullptr)
            {
                report << "\t" << xc_func_reference_get_ref(ref) << '\n';
            }
        }
    }
    XC_Functional_Libxc::finish_func(funcs);
    return report.str();
#else
    std::string names = " XC:\t";
    for (const int id : func_id)
    {
        names += std::to_string(id) + "\t";
    }
    return names;
#endif
}
