#!/usr/bin/env python

import os, sys
import glob
import numpy as np
import time
import calendar
from datetime import datetime, timedelta
import inspect
from netCDF4 import Dataset
from gridded_area import do_grid_p as dgp
import gc

# ============= READ ANT =============
def readant(fname,outheight):
    """Read the seven ECLIPSE anthropogenic emission sectors and convert them to ng/m3/s.

    Parameters
    ----------
    fname : str
        Path to a netCDF file holding the variables ``emis_dom``, ``emis_ene``,
        ``emis_flr``, ``emis_ind``, ``emis_shp``, ``emis_tra`` and ``emis_wst``,
        each in kg/m2/h.
    outheight : sequence of float
        Output-level heights in metres. The surface flux is spread over
        ``outheight[0]`` metres (presumably the top of the lowest output layer).

    Returns
    -------
    tuple
        ``(dom, ene, flr, ind, shp, tra, wst)`` in ng/m3/s, in that order:
        kg/m2/h * 1.e12 / (3600 sec/h * m) == ng/m3/s.
    """
    sectors = ('dom', 'ene', 'flr', 'ind', 'shp', 'tra', 'wst')
    depth = 3600*outheight[0]   # seconds per hour * layer depth (m)
    # The context manager guarantees the file handle is released; [:] reads
    # each variable into memory before the file is closed.
    with Dataset(fname, mode='r') as nc:
        return tuple(nc.variables['emis_'+sector][:]*1.e12/depth for sector in sectors)

# ============= READ BB GFED4 ============
def readbb(fname,outheight):
    """Read GFED4 daily biomass-burning emissions and convert them to ng/m3/s.

    Parameters
    ----------
    fname : str
        Path to a netCDF file with the variable ``bbemi`` in g m-2 d-1.
    outheight : sequence of float
        Output-level heights in metres. The flux is spread over
        ``outheight[-2]`` metres instead of the lowest layer, so biomass-burning
        emissions reach higher altitudes to account for large fires.

    Returns
    -------
    array
        ``bbemi`` in ng/m3/s: g/m2/d * 1.e9 / (86400 sec/d * m) == ng/m3/s.
    """
    with Dataset(fname, mode='r') as nc:
        bbemi = nc.variables['bbemi'][:]    # g m-2 d-1
    return bbemi*1.e9/(86400*outheight[-2])

# ============= READ BB CAMS ============
def readbb_cams(fname, outheight, ssp, split_height=3000):
    """Read CAMS-GFAS fire emissions and split them by injection height.

    Parameters
    ----------
    fname : str
        Path to a netCDF file holding ``<ssp>fire`` (kg m-2 s-1, daily steps)
        and ``apt`` (injection height in metres).
    outheight : sequence of float
        Output-level heights in metres. Emissions injected at or below
        *split_height* are spread over ``outheight[-2]`` metres; those injected
        above it are spread over ``outheight[-1] - outheight[-2]`` metres.
    ssp : str
        Prefix of the fire variable name (the variable read is ``ssp + 'fire'``).
    split_height : float, optional
        Injection-height threshold in metres (default 3000, i.e. 3 km).

    Returns
    -------
    tuple
        ``(bbemi1, bbemi2)``: the low- and high-injection fields in ng/m3/s
        (kg/m2/s * 1.e12 / m == ng/m3/s), with the latitude axis (axis 1)
        reversed relative to the file. Cells where ``apt`` is neither
        ``<= split_height`` nor ``> split_height`` (e.g. NaN) stay zero in both.
    """
    with Dataset(fname, mode='r') as nc:
        bbemi = nc.variables[ssp+'fire'][:]    # kg m-2 s-1 for daily steps
        apt = nc.variables['apt'][:]           # injection height in meters

    # Reverse latitude (axis 1) of both fields so they stay aligned.
    bbemi = bbemi[:,::-1,:]
    apt = apt[:,::-1,:]

    # Find indices where the injection height is below/above the threshold.
    print('Find indices')
    indices1 = np.where(apt <= split_height)
    indices2 = np.where(apt > split_height)

    # Create new arrays filled with zeros.
    print('Create new arrays filled with zeros')
    bbemi1 = np.zeros_like(bbemi)
    bbemi2 = np.zeros_like(bbemi)

    # Copy emissions into the low- or high-injection array.
    print('Assign values to new arrays')
    bbemi1[indices1] = bbemi[indices1]
    bbemi2[indices2] = bbemi[indices2]

    return bbemi1*1.e12/outheight[-2], bbemi2*1.e12/(outheight[-1] - outheight[-2])

# ============= READ BB FINN ============
def readbb_finn(fname, outheight):
    """Read FINN fire emissions and convert them to ng/m3/s.

    Parameters
    ----------
    fname : str
        Path to a netCDF file with the variable ``bbemi`` in kg m-2 s-1
        (daily steps).
    outheight : sequence of float
        Output-level heights in metres. The flux is spread over
        ``outheight[-2]`` metres, as for the other biomass-burning readers.

    Returns
    -------
    array
        ``bbemi`` in ng/m3/s: kg/m2/s * 1.e12 / m == ng/m3/s.
    """
    with Dataset(fname, mode='r') as nc:
        bbemi = nc.variables['bbemi'][:]    # kg m-2 s-1 for daily steps
    return bbemi*1.e12/outheight[-2]

# ECLIPSE anthropogenic BC inventory read by Get_data_bwd.
_ANT_FILE = '/xnilu_wrk/users/ne/BC_INVENTORIES/ECLIPSE_ANT/V6b/ECLIPSE_V6b_CLE_base_BC_2020-3D_noAWB.nc'
# Biomass-burning inventories understood by Get_data_bwd.
_BB_INVENTORIES = ('GFED', 'CAMS', 'FINN')


def _read_bb_emissions(bb, species, year, outheight):
    """Return ``(ebb1, ebb2)`` biomass-burning emissions in ng/m3/s for inventory *bb*.

    For CAMS, ``ebb1``/``ebb2`` are the low/high injection-height fields from
    ``readbb_cams``. For GFED and FINN, ``ebb2`` is an all-zero array shaped
    like the first three dimensions of ``ebb1``.

    Raises
    ------
    ValueError
        If *bb* is not one of ``_BB_INVENTORIES``.
    """
    if bb == 'GFED':
        ebb1 = readbb('/xnilu_wrk/users/ne/GFEDv4/720x360_daily/GFED4.1s_'+species+'_'+str(year)+'_720x360_daily.nc',outheight)
        print('GFED4 year',year)
        ebb2 = np.zeros(ebb1.shape[:3])
    elif bb == 'CAMS':
        print('CAMS year',year)
        ebb1,ebb2 = readbb_cams('/xnilu_wrk/users/ne/BC_INVENTORIES/CAMS_GFAS/720x360/gfas_'+str(year)+'_720x360_'+species+'.nc', outheight, species.lower())
    elif bb == 'FINN':
        ebb1 = readbb_finn('/xnilu_wrk/users/ne/BC_INVENTORIES/FINNv2.5/720x360/FINN_720x360_'+str(year)+'.nc', outheight)
        print('FINN year',year)
        ebb2 = np.zeros(ebb1.shape[:3])
    else:
        raise ValueError("Unknown biomass-burning inventory bb=%r; expected one of %s" % (bb, _BB_INVENTORIES))
    return ebb1, ebb2


def Get_data_bwd(fdir,bb,spout = '001'):
    """Read backward FLEXPART grid output plus anthropogenic and BB emissions.

    Parameters
    ----------
    fdir : str
        Directory containing a ``grid_*.nc`` file (the first glob match is used).
    bb : str
        Biomass-burning inventory: ``'GFED'``, ``'CAMS'`` or ``'FINN'``.
    spout : str, optional
        Species suffix of the ``spec<spout>`` variables (default ``'001'``).

    Returns
    -------
    tuple
        In this order:
        ``ctmp, rld0, rld1, ctime`` (per-release fields at the last age class
        with all height levels, release start/end datetimes, per-release times);
        ``edom, eene, eflr, eind, eshp, ewst, etra`` (anthropogenic sectors,
        ng/m3/s; note ``ewst`` precedes ``etra``);
        ``ebb1, ebb2`` (biomass burning, ng/m3/s);
        then the grid/run metadata ``dxout ... zpoint2``, the species name,
        ``spout``, ``lons``, ``lats`` and ``outheight``.

    Raises
    ------
    ValueError
        If *bb* is unknown or the grid file contains no release groups.
    FileNotFoundError
        If *fdir* contains no ``grid_*.nc`` file.
    """
    # Fail fast, before any heavy I/O, on an unsupported inventory.
    if bb not in _BB_INVENTORIES:
        raise ValueError("Unknown biomass-burning inventory bb=%r; expected one of %s" % (bb, _BB_INVENTORIES))

    grid_files = glob.glob(fdir + "/grid_*.nc")
    if not grid_files:
        raise FileNotFoundError("No grid_*.nc file found in " + fdir)

    with Dataset(grid_files[0], mode='r') as nc:
        lons = nc.variables['longitude'][:]
        lats = nc.variables['latitude'][:]
        outheight = nc.variables['height'][:]             #meters

        dxout = nc.getncattr("dxout")
        dyout = nc.getncattr("dyout")
        iedate = nc.getncattr("iedate")   # GET END DATE FOR WHICH TIMESTEPS WILL GO BACK IN TIME
        ietime = nc.getncattr("ietime")   # GET END TIME FOR WHICH TIMESTEPS WILL GO BACK IN TIME
        ind_rec = nc.getncattr("ind_receptor")
        ind_src = nc.getncattr("ind_source")
        ireleaseend = nc.variables['RELEND'][:]       # END OF RELEASES, OFFSET IN SECONDS (NEGATIVE VALUES)
        ireleasestart = nc.variables['RELSTART'][:]   # START OF RELEASES, OFFSET IN SECONDS (NEGATIVE VALUES)
        kindp = nc.variables['RELKINDZ'][:]
        lage = nc.variables['LAGE'][:]
        lconvection = nc.getncattr("lconvection")
        loutaver = nc.getncattr("loutaver")
        loutsample = nc.getncattr("loutsample")
        loutstep = nc.getncattr("loutstep")
        lsubgrid = nc.getncattr("lsubgrid")
        outlon0 = nc.getncattr("outlon0")
        outlat0 = nc.getncattr("outlat0")
        nxgrid = len(lons)
        nygrid = len(lats)
        npart = nc.variables['RELPART'][:]
        xmass = nc.variables['RELXMASS'][:]
        nageclass = nc.dimensions['nageclass'].size
        numpoint = nc.dimensions['numpoint'].size
        xpoint1 = nc.variables['RELLNG1'][:]
        xpoint2 = nc.variables['RELLNG2'][:]
        ypoint1 = nc.variables['RELLAT1'][:]
        ypoint2 = nc.variables['RELLAT2'][:]
        zpoint1 = nc.variables['RELZZ1'][:]
        zpoint2 = nc.variables['RELZZ2'][:]

        print('Outheight is always a list ',outheight)
        nageidx = nageclass-1   # index of the last age class

        ctmp = []; species = []; ctime = []
        for key in nc.groups:
            print("Reading release: "+key+ " out of "+ str(numpoint) )
            spec = nc[key].variables['spec'+spout]
            # Take all heights (not only the surface) so BB can be applied aloft.
            # Read once; the shape print below reuses the in-memory array.
            field = spec[nageidx,:,:,:,:]
            ctmp.append(field)
            print("Shape of array", np.shape(field))
            ctime.append(nc[key].variables['time'][:] )
            species.append(spec.long_name)

    if not species:
        raise ValueError("No release groups found in " + grid_files[0])

    print('SPecies to be plotted is ',species[0].lower())
    print('ind_src & ind_rec are ', ind_src, ' ',ind_rec)

    # Release start/end datetimes: the RELSTART/RELEND offsets (seconds) are
    # added to the simulation end date/time.
    sim_endtime = datetime.strptime(str(iedate)+str(ietime).zfill(6), "%Y%m%d%H%M%S")
    rld0 = [sim_endtime + timedelta(seconds=int(ii)) for ii in ireleasestart]
    rld1 = [sim_endtime + timedelta(seconds=int(ii)) for ii in ireleaseend]; print('rld1',rld1[0],rld1[-1])

    # READ EMISSIONS - open the anthropogenic file once; readant returns the
    # sectors in the order (dom, ene, flr, ind, shp, tra, wst).
    edom, eene, eflr, eind, eshp, etra, ewst = readant(_ANT_FILE, outheight)
    # The BB inventory year is taken from the end date of the second-to-last release.
    ebb1, ebb2 = _read_bb_emissions(bb, species[0], rld1[-2].year, outheight)

    # NOTE(review): the original assigned nzgrid = len(outheight) and then
    # reset it to 0; the returned value has always been 0 and is kept for
    # callers - confirm the intended meaning.
    nzgrid = 0

    return ctmp,rld0,rld1,ctime,edom,eene,eflr,eind,eshp,ewst,etra,ebb1,ebb2,dxout,dyout,iedate,ietime,ind_rec,ind_src,ireleaseend,ireleasestart,kindp,lage,lconvection,loutaver,loutsample,loutstep,lsubgrid,nageclass,npart,numpoint,nxgrid,nygrid,nzgrid,outlat0,outlon0,xmass,xpoint1,xpoint2,ypoint1,ypoint2,zpoint1,zpoint2,species[0],spout,lons,lats,outheight
