#!/usr/bin/env python3

import os,glob,sys
import numpy as np
from netCDF4 import Dataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.basemap import Basemap, cm
import time
import calendar
from datetime import datetime,date, timedelta 
import matplotlib.dates as mdates
from matplotlib import dates,gridspec
import pandas as pd

# ========== HARD-CODED: ONLY CHANGE MAIN.py ==========
from MAIN import imp

# Query the run configuration a single time so that any work or side effects
# inside imp() are not repeated just to inspect the precipitation flag.
_settings = imp()
if _settings[3]:
    # Element 3 (the precipitation flag) is truthy: imp() returns six values,
    # including the precipitation file.
    fdir, dirstat, stat, precip, precfile, bb = _settings
    print('1', 'L25')
else:
    # Five values are returned in this configuration; define precfile anyway
    # so the name always exists.
    fdir, dirstat, stat, precip, bb = _settings
    precfile = None
    print('2', 'L28')

from get_data_src import Get_data_bwd  # get_data_ant first multiplies output footprint and emissions and then sums them
print('Precipitation is', precip)

# ============= READ PRECIPITATION ===============
def prreader(fname):
    """Read the precipitation column of a whitespace-separated text file.

    The first line of the file is treated as a header and skipped; blank
    lines are ignored.  Only column 5 (0-based) is parsed.  Columns 0-3
    are present in the file but are not needed by this script, so they are
    no longer parsed.

    Parameters
    ----------
    fname : str
        Path of the precipitation file.

    Returns
    -------
    numpy.ndarray
        1-D array with one precipitation value per data row, in file order.

    Raises
    ------
    IndexError
        If a non-blank data row has fewer than six columns.
    ValueError
        If column 5 of a data row is not a valid float.
    """
    pr = []
    # The context manager guarantees the file is closed, also on parse errors.
    with open(fname, 'r') as prec_file:
        next(prec_file, None)  # skip the header line (no-op on an empty file)
        for row in prec_file:
            cols = row.split()
            if cols:
                # in mm of water eq. NOTE: take prec(accmlt) from ECMWF; the other one is from input
                pr.append(float(cols[5]))
    return np.array(pr)

if precip:      # If this is set to true in MAIN then read precipitation and create ng/g of snow
    pr = prreader(precfile)

# Observation time series: lines with index 0-2 are skipped as header, blank
# lines are ignored.  Columns 0 and 1 are concatenated and parsed as
# "%d/%m/%Y%H:%M"; column 4 holds the observed value.
dd = []; obs = []
# Open inside a context manager so the handle is closed even if a row fails to parse.
with open('eBC_NO0002_June_2023_For_SEC_NE_MD_Report_19_05_2024.txt','r') as obs_file:
    for i, line in enumerate(obs_file):
        if i > 2:   # line number
            line = line.split()
            if len(line) > 0:
                dd.append(datetime.strptime(line[0]+line[1], "%d/%m/%Y%H:%M"))
                obs.append(float(line[4]))

print(len(obs))
#============== READ source =============
ctmp,dat,datE,ctime,edom,eene,eflr,eind,eshp,ewst,etra,ebb1,ebb2,dxout,dyout,iedate,ietime,ind_rec,ind_src,ireleaseend,ireleasestart,kindp,lage,lconvection,loutaver,loutsample,loutstep,lsubgrid,nageclass,npart,numpoint,nxgrid,nygrid,nzgrid,outlat0,outlon0,xmass,xpoint1,xpoint2,ypoint1,ypoint2,zpoint1,zpoint2,species,spout,lons,lats,outheight = Get_data_bwd(fdir,bb,'001')

# Calculate weights for BB
wei1 = outheight[0]/outheight[1]
wei2 = (outheight[1]-outheight[0])/outheight[1]
print("Weights for footprint that will be used in BB ", wei1, wei2)
# NOTE(review): wei1/wei2 are only printed.  The biomass-burning term below
# uses the plain sum of footprint levels 0 and 1 instead of the weighted form
# (ctmp[..,0]*wei1 + ctmp[..,1]*wei2) -- confirm which one is intended.

counter = 1
nzgrid = 0  # to manually get the lowest level
nageidx = nageclass-1

# These ind_src/ind_rec combinations select the deposition branch; every other
# combination accumulates footprint*emission without further scaling.
is_deposition = (ind_src == 1 and ind_rec == 3) or (ind_src == 1 and ind_rec == 4)

# Emission fields indexed by month (pyindex), in the order of cdom ... ctra.
anthro_emissions = (edom, eene, eflr, eind, eshp, ewst, etra)


def _to_output_units(field, rel):
    """Scale a footprint*emission field of release `rel` to the output units.

    * deposition with precipitation:    field * release duration [s] * 1.e-3 / pr[rel]
    * deposition without precipitation: field * release duration [s] * 1.e6
    * otherwise:                        field unchanged

    The operand order matches the original expressions so results are
    bit-for-bit unchanged.  `pr` is only touched when `precip` is true.
    """
    if not is_deposition:
        return field  # s *ng/m3/s  = ng/m3
    release_seconds = abs(dat[rel]-datE[rel]).total_seconds()
    if precip:
        return field * release_seconds *1.e-3/pr[rel]  # m *ng/m3/s * sec = ng/m2 (*1.e-3/pr = ng/g)
    # m *ng/m3/s * sec = ng/m2 *1.e6 = mg/m2
    # NOTE(review): ng -> mg would normally be a factor 1.e-6, not 1.e6;
    # confirm the emission units before trusting the "mg m-2" axis label.
    return field * release_seconds *1.e6


def _bb_field(rel, step, day):
    """Biomass-burning footprint*emission field for emission day index `day`.

    Footprint levels 0 and 1 are summed and multiplied by ebb1[day]; level 2
    is multiplied by ebb2[day].  Raises IndexError when `day` (or level 2) is
    out of range for the arrays involved.
    """
    return ( np.sum(ctmp[rel][step,0:2,:,:],axis=0) * ebb1[day,:,:] ) + \
           ( ctmp[rel][step,2,:,:] * ebb2[day,:,:] )


# Per-release domain sums: seven monthly-emission sectors followed by BB.
sector_series = [[] for _ in range(len(anthro_emissions) + 1)]

# === LOOP ===
for ii in range(len(ctmp)):     # FOR BACKWARDS
    print ("Calculating concentrations for release "+str(ii)+ " out of "+str(len(ctmp)) )
    print("RelSTART", dat[ii], "RelEND", datE[ii])

    anthro_grids = [np.zeros((nygrid,nxgrid)) for _ in anthro_emissions]
    bb_grid = np.zeros((nygrid,nxgrid))
    for j in range(ctmp[ii].shape[0]):
        dtemp = datE[ii] + timedelta(seconds = float(ctime[ii][j]) )

        # datetime.month is always 1..12, so month-1 is a valid 0-based index
        # for every month; the former "month-1 < 0" branch could never run.
        pyindex = dtemp.month-1
        print('DIAGNOSTICS: pyindex for all months except December ', pyindex)
        pyindex2 = dtemp.timetuple().tm_yday-1
        print ("DIAGNOSTICS: index of day from BB is ", pyindex2)

        # Lowest footprint level times the monthly emission field, per sector.
        for grid, emis in zip(anthro_grids, anthro_emissions):
            grid += _to_output_units(ctmp[ii][j,0,:,:] * emis[pyindex,:,:], ii)

        try:
            fire = _bb_field(ii, j, pyindex2)
        except IndexError:
            # Day index not available in the BB emission arrays (presumably
            # day 366 with a 365-day inventory): reuse the previous day.
            print("Be careful with pyindex2 exception")
            fire = _bb_field(ii, j, pyindex2-1)
        bb_grid += _to_output_units(fire, ii)

    for series, grid in zip(sector_series, anthro_grids + [bb_grid]):
        series.append(grid.sum())

# One float array per sector, one value per release.
cdom, cene, cflr, cind, cshp, cwst, ctra, cbb = (
    np.array(series, dtype=float) for series in sector_series
)

ctot = cdom+cene+cflr+cind+cshp+cwst+ctra+cbb


### PLOT ###

Fmt = mdates.DateFormatter('%d-%b-%Y')
# ========= Calculating =========
# Running sums over the sectors: dtotK is the upper edge of the K-th stacked
# band (same left-to-right addition order as writing out cdom+cene+...).
dtot1 = cdom
dtot2 = dtot1+cene
dtot3 = dtot2+cflr
dtot4 = dtot3+cind
dtot5 = dtot4+cshp
dtot6 = dtot5+cwst
dtot7 = dtot6+ctra
dtot8 = dtot7+cbb


#========== PLOTTING ============
# One colour per sector, in stacking order: DOM, ENE, FLR, IND, SHP, WST, TRA, BB.
colorList = ['#c42b37','#397cb4','#6e8172','#c76455','#e6d130','#f480bd','mediumpurple','k']

fig = plt.figure(figsize=(10,6))

gs = gridspec.GridSpec(2, 1,height_ratios=[2,1])
gs00 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs[0])
gs01 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[1])

#================================
C3=plt.subplot(gs00[0])

# Stacked bands: each sector fills the gap between consecutive running sums.
band_edges = [0, dtot1, dtot2, dtot3, dtot4, dtot5, dtot6, dtot7, dtot8]
for lower, upper, band_color in zip(band_edges[:-1], band_edges[1:], colorList):
    C3.fill_between(datE, lower, upper, color=band_color, interpolate=False)

C3.plot(dd, obs, color='lime',linewidth=2,linestyle='-',label='eBC observations')

C3.xaxis.set_major_locator(dates.DayLocator(interval=3))
C3.xaxis.set_major_formatter(Fmt)
plt.xticks(rotation=30,ha="right")

# Same ind_src/ind_rec test that selects the deposition branch in the calculation.
deposition_run = (ind_src == 1 and ind_rec == 3) or (ind_src == 1 and ind_rec == 4)
if deposition_run:
    plotted_quantity = 'deposition'
    if precip:
        C3.set_ylabel('Deposition (ng g'+r'$\mathrm{\mathsf{^{-1}}}$)',fontsize=14)
    else:
        C3.set_ylabel('Deposition (mg m'+r'$\mathrm{\mathsf{^{-2}}}$)',fontsize=14)
else:
    plotted_quantity = 'concentrations'
    C3.set_ylabel('Surface concentration (ng m'+r'$\mathrm{\mathsf{^{-3}}}$)',fontsize=14)

# Inventory name as shown in the title; any other value of bb leaves the plot
# untitled, as before.
bb_title_name = {'GFED': 'GFED4', 'CAMS': 'CAMS', 'FINN': 'FINN'}.get(bb)
if bb_title_name is not None:
    C3.set_title('Source contribution to BC ' + plotted_quantity
                 + ' (ECLIPSEv6 and ' + bb_title_name + ')')
C3.legend(loc=0,markerscale=4)

#==========================
# Colour key: a horizontal colour bar with one cell per sector.
C4=plt.subplot(gs01[1])

clevs = [1,2,3,4,5,6,7,8]
bounds = [0,1,2,3,4,5,6,7,8]
cmap = mpl.colors.ListedColormap(colorList)
norm = mpl.colors.Normalize(vmin=0, vmax=8)
cb1 = mpl.colorbar.ColorbarBase(C4, cmap=cmap,
                                    norm=norm,
                                    boundaries=bounds,
                                    orientation='horizontal')
cb1.ax.set_xticklabels(['0','1','2','3','4','5','6','7','8'],visible=False)

# Sector captions are positioned as fractions of the colour-bar axes.  Pinning
# them to transAxes keeps them centred on their cells regardless of whether
# the colour-bar axes use 0-1 limits or the data limits (0..8).
caption_style = dict(horizontalalignment='center', verticalalignment='top',
                     transform=C4.transAxes)
for caption_x, caption in [(0.07, 'DOM'), (0.19, 'ENE'), (0.31, 'FLR'), (0.44, 'IND'),
                           (0.56, 'SHP'), (0.69, 'WST'), (0.81, 'TRA')]:
    C4.text(caption_x, 0.65, caption, **caption_style)
# The BB cell is black, so its caption is drawn in white.
C4.text(0.94, 0.65, 'BB', color='white', **caption_style)

# ======= SAVE  =======

path = "./"
name = "SOURCES_" + species + "_a" + ".png"

print ('Plotting src contr ', path, name)
plt.savefig(os.path.join(path, name))

