################################################################################
# Written By Jared Rennie
################################################################################

# Import Packages
import sys, time, datetime, os
import numpy as np
import geopandas as gpd
import pandas as pd

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader

import shapely.geometry as sgeom
import warnings
warnings.filterwarnings("ignore")

# Define directories: every input and output lives under one project root.
main_directory = "/store/sfcnet/datasets/fema_risk"
inShapefile_directory = f"{main_directory}/input_shapefile"     # county boundary shapefile
fema_directory = f"{main_directory}/source"                     # FEMA NRI CSV
outShapefile_directory = f"{main_directory}/results_shapefile"  # joined output shapefile
plot_directory = f"{main_directory}/results_plots"              # rendered PNG maps

#################################################
# BEGIN PROGRAM
start = time.time()  # wall-clock start, reported as "Runtime" at the end

#################################################
# READ IN SHAPEFILE
# County polygons. GEOID presumably holds the state+county FIPS code as text;
# it is cast to a 32-bit int so it can be joined against the FEMA
# 'State-County FIPS Code' column (which is cast to int as well).
input_shapefile = f"{inShapefile_directory}/cb_2022_us_county_500k.shp"

print("READ IN SHAPEFILE: ", input_shapefile)
geo_shapefile = gpd.read_file(input_shapefile)
raw_geoids = geo_shapefile['GEOID'].values
geo_shapefile['GEOID2_INT'] = np.array(raw_geoids, dtype='i')

# Remember the source CRS so the joined output can be tagged with it
projection = geo_shapefile.crs

#################################################
# READ IN CSV File (fema)
input_csv = fema_directory + '/NRI_Counties_Prod.csv'
print("READING IN fema DATA: ", input_csv)

# Only these columns are used downstream. Passing them as `usecols` means the
# remaining columns of the NRI table are never parsed into the frame.
fema_columns = [
    'State-County FIPS Code',
    'State Name Abbreviation',
    'National Risk Index - Rating - Composite',
    'Heat Wave - Hazard Type Risk Index Rating',
]
data_fema = pd.read_csv(input_csv, sep=',', usecols=fema_columns)

# Clean Up:
# - keep the columns in the listed order (usecols returns them in file order),
# - add an integer join key matching geo_shapefile['GEOID2_INT'],
# - sort by that key.
# `.assign` builds a new frame rather than writing into a column slice, which
# would trigger pandas' SettingWithCopyWarning.
data_fema = (
    data_fema[fema_columns]
    .assign(GEOID2_INT=lambda df: df['State-County FIPS Code'].astype(int))
    .sort_values(by=['GEOID2_INT'])
)

#################################################
# Join
print("\nJOIN")

# Left join on the integer FIPS key: every county polygon is kept even when
# FEMA has no matching record (its FEMA columns are then missing).
out_shapefile = geo_shapefile.merge(data_fema, how='left', on='GEOID2_INT')

# Tag the joined frame with the county shapefile's CRS
out_shapefile = out_shapefile.set_crs(projection)

# Persist the joined result as a new shapefile
finalShapeFile = f"{outShapefile_directory}/cdc_fema_fromcsv.shp"
print(f"OUTPUT TO: {finalShapeFile}")
out_shapefile.to_file(finalShapeFile)

#################################################
# PLOTTING (FEMA RISK)
print("PLOTTING (FEMA RISK)")

# Set Bounds: lon/lat extent of the main (CONUS) map, in degrees
minLat = 22
maxLat = 50
minLon = -120
maxLon = -73

dpi = 300
plt.style.use('dark_background')
land_hex = '#efefef'   # ESRI Light Gray Canvas
ocean_hex = '#cfd3d4'  # ESRI Light Gray Canvas

# Grab Data By variable and plot.
# Alternative configuration (composite index instead of heat waves only):
#inCode='National Risk Index - Rating - Composite'
#inTitle='FEMA National Risk Index (Dec 2025)'
#inName='fema'

inCode = 'Heat Wave - Hazard Type Risk Index Rating'  # FEMA column that is mapped
inTitle = 'FEMA National Risk Index (Heat Waves Only, Dec 2025)'
inName = 'fema'

# Set Up Figure
fig = plt.figure(num=1, figsize=(8, 5), dpi=dpi, facecolor='w', edgecolor='k')

# CONUS AXES: main map filling the whole figure.
# The extent comes from the bounds above so editing them actually moves the map.
conus_ax = fig.add_axes([0, 0, 1, 1], projection=ccrs.LambertConformal())
conus_ax.set_facecolor(ocean_hex)
conus_ax.set_extent([minLon, maxLon, minLat, maxLat], crs=ccrs.Geodetic())

# ALASKA AXES: inset in the lower-left corner
ak_ax = fig.add_axes([0.05, 0.01, 0.20, 0.20], projection=ccrs.Orthographic(central_longitude=-133.66666667, central_latitude=57.00000000))
ak_ax.set_facecolor(ocean_hex)
ak_ax.set_extent([-184, -128, 67, 53], crs=ccrs.Geodetic())

# HAWAII AXES: inset next to Alaska
hi_ax = fig.add_axes([0.25, 0.01, 0.15, 0.15], projection=ccrs.Mercator())
hi_ax.set_facecolor(ocean_hex)
hi_ax.set_extent([-162, -154, 18, 23], crs=ccrs.Geodetic())

# PUERTO RICO AXES: inset along the bottom edge
pr_ax = fig.add_axes([0.60, 0.01, 0.15, 0.15], projection=ccrs.Mercator())
pr_ax.set_facecolor(ocean_hex)
pr_ax.set_extent([-67.5, -65.5, 17.75, 18.75], crs=ccrs.Geodetic())

# Plot Data For Each County

# Fill color for each FEMA rating string
RATING_COLORS = {
    'Insufficient Data': '#9E9E9E',
    'Not Applicable': '#CCCCCC',
    'No Rating': '#FFFFFF',
    'Very Low': '#4D6DBD',
    'Relatively Low': '#509bc7',
    'Relatively Moderate': '#f0d55d',
    'Relatively High': '#e07068',
    'Very High': '#c7445d',
}


def rating_to_color(val):
    """Return the fill color for one FEMA rating value.

    Missing values (NaN, e.g. a county with no FEMA match after the left join)
    get the 'Not Applicable' color. Any unrecognised rating falls back to the
    'Insufficient Data' gray.
    """
    # NaN from pandas is a float, never the string 'nan', so test it explicitly.
    if pd.isna(val):
        return '#CCCCCC'
    return RATING_COLORS.get(val, '#9E9E9E')


# Iterate the joined GeoDataFrame directly so each polygon is paired with its
# own attribute row. This avoids re-reading the shapefile and relying on
# positional alignment between the file and the frame.
for county, val, stateAbbr in zip(out_shapefile.geometry,
                                   out_shapefile[inCode],
                                   out_shapefile['State Name Abbreviation']):
    outColor = rating_to_color(val)

    # Each county goes on exactly one axes: its inset (AK/HI/PR) or the
    # CONUS map. Without `elif`, AK and HI counties would also be drawn on
    # the CONUS axes.
    if stateAbbr == 'AK':
        ak_ax.add_geometries([county], ccrs.PlateCarree(), facecolor=outColor, edgecolor='black', linewidth=0.20)
    elif stateAbbr == 'HI':
        hi_ax.add_geometries([county], ccrs.PlateCarree(), facecolor=outColor, edgecolor='black', linewidth=0.50)
    elif stateAbbr == 'PR':
        pr_ax.add_geometries([county], ccrs.PlateCarree(), facecolor=outColor, edgecolor='black', linewidth=0.20)
    else:
        conus_ax.add_geometries([county], ccrs.PlateCarree(), facecolor=outColor, edgecolor='black', linewidth=0.10)

# State outlines drawn above the county fills
conus_ax.add_feature(cfeature.STATES, linewidth=0.5, zorder=10)

# Add ColorMap: one discrete swatch per rating class. The order must match
# `labels` below and the per-county fill colors.
cmap = mpl.colors.ListedColormap(['#9E9E9E','#CCCCCC','#FFFFFF','#4D6DBD','#509bc7','#f0d55d','#e07068','#c7445d'])
bounds = np.arange(cmap.N + 1)  # integer edges 0..N -> one unit-wide bin per class
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
cax = fig.add_axes([0.1, -0.035, 0.8, 0.03])  # horizontal strip just below the map
cbar = plt.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), cax=cax, boundaries=bounds, ticks=bounds, spacing='uniform', orientation='horizontal')

# Tick labels, one centered in each unit-wide bin
labels = np.array(['Insufficient Data','Not Applicable','No Rating','Very Low','Relatively Low', 'Relatively Moderate', 'Relatively High', 'Very High'],dtype='str')
tick_locations = np.arange(len(labels)) + 0.5  # 0.5, 1.5, ... bin centers
cbar.set_ticks(tick_locations)
cbar.set_ticklabels(labels)
cbar.ax.tick_params(labelsize=6)

# Add Titles: figure-level title above the map, plus a source/credit note.
title_style = dict(size=15, color='white', y=1.05)
plt.suptitle(inTitle, **title_style)

credit_text = 'Source: FEMA NRI v1.20 (Dec 2025)\nMade By Jared Rennie (@jjrennie)'
# xy is in axes-fraction units of pyplot's current axes (presumably the
# colorbar axes added last; verify if axes creation order changes).
plt.annotate(
    credit_text,
    xy=(1.045, -3.51),
    xycoords='axes fraction',
    fontsize=7,
    backgroundcolor='black',
    color='white',
    horizontalalignment='right',
    verticalalignment='bottom',
)

# Save Figure, then clear and close it
png_path = f"{plot_directory}/femaRisk_HEAT.png"
plt.savefig(png_path, bbox_inches='tight')
plt.clf()
plt.close()

####################
# DONE
####################
print("DONE!")
end = time.time()
elapsed = end - start  # seconds since `start` was recorded at program begin
print("Runtime: {:8.1f} seconds".format(elapsed))
sys.exit()
