#Python code to convert FERRE grid format spectra to individual fits format files based on user specified stellar parameters
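#To be placed in (and run from) the same directory as MainGrid.dat, IntermediateGrid.dat, HighGrid.dat and CaFeGrid.dat, together with their grid order files (3500_6000K_Grid_Order.txt, 6000_8000K_Grid_Order.txt, 8000_10000K_Grid_Order.txt, CaFe_Grid_Order.txt)
#Extracts a single spectrum into a .fits file named TxxxxgxxxMHxxAMxxCMxx.fits inside the appropriate grid directory - see README file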
#ATK April 2019
#AES July 2019 - minor upgrades to directory descriptions.
#AES 24 Oct 2019 - CRDELT1 changed to CDELT1 (Fits standard & required by ATK's Diagnostic_Tool_Draft.py software for plotting ratios of spectra).


import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import math as m


from numpy import loadtxt
from astropy.io import fits
from tqdm import tqdm

import os 

#Working directory: the grid sub-directories (Main, Intermediate, High, CaFe) are created under it and the .fits files are written there
cwd=os.getcwd()

#Wavelength solution written into the header of every extracted spectrum
#Number of pixels per spectrum (stored as NAXIS1)
Npix=146497

#Wavelength of the first pixel, in Angstroms (stored as CRVAL1)
lambda1=1677.10

#Pixel width in Angstroms (stored as CDELT1 and CD1_1)
deltalambda=0.05

#Show the available grids before asking the user which one to extract from.
#Joining with newlines reproduces the original one-print-per-line output exactly.
print('\n'.join([
	'\n',
	'3 main grids:',
	'Main Grid (3500-6000K) - Main',
	'Intermediate T Grid (6000-8000K) - Intermediate',
	'High T Grid (8000-10000K) - High',
	' ',
	'Additional grid:',
	'[Ca/Fe]=0 Grid - CaFe',
	' ',
]))

#User input of the spectrum to extract, based on its stellar parameters. The first question lets the user choose between the 3 main grids (MIH) and the [Ca/Fe]=0 grid (CaFe).
#Numbers are parsed with int()/float() instead of Python 2's input(), which evaluates whatever is typed as a Python expression.

try:
	_read_line=raw_input	#Python 2
except NameError:
	_read_line=input	#Python 3, where input() returns the typed text unevaluated

def _ask_number(prompt):
	"""Prompt until the user types a number, then return it.

	Whole numbers come back as int and anything else as float, which is what Python 2's
	input() produced for plain numeric text. str() of the value is therefore unchanged for
	the output file name, e.g. '3500' -> 3500 and '0.50' -> 0.5.
	"""
	while True:
		text=_read_line(prompt).strip()
		for convert in (int,float):
			try:
				return convert(text)
			except ValueError:
				pass
		print('Please type the numerical value only (e.g. 3500 for 3500K).')

#Keep asking until a recognised grid is chosen, so every parameter below is always defined.
#The choice is case-insensitive but is stored in the exact spelling the rest of the script tests for.
_grid_choices={'mih':'MIH','cafe':'CaFe'}
CaFe_input=None
while CaFe_input is None:
	_choice=_read_line('Extract from the 3 main grids with Ca tracking alpha elements (type MIH), \n or from the subgrid with Ca tracking Fe (type CaFe) :').strip().lower()
	CaFe_input=_grid_choices.get(_choice)
	if CaFe_input is None:
		print('Please type MIH or CaFe.')

print('\n')
print('User must pick an existing grid point - see README file for grid ranges. \n ')

print('Type the numerical values of the parameters only (e.g. 3500 for 3500K): \n')

if CaFe_input=='CaFe':
	Teff=_ask_number('Enter the effective temperature of the star you wish to extract (3500 - 6000K in steps of 250K) :')
	Logg=_ask_number('Enter the surface gravity of the star you wish to extract (0.0 - 5.0 dex in steps of 0.5 dex) :')
	MH=_ask_number('Enter the metallicity ([M/H]) of the star you wish to extract (-2.5,-2.0,-1.5,-1.0,-0.5,0.0,0.5 dex) :')
	AlphaM=_ask_number('Enter the alpha enhancement ([alpha/M]) of the star you wish to extract (0.25 dex for [Ca/Fe]=0 grid) :')
	CM=_ask_number('Enter the carbon enhancement ([C/M]) of the star you wish to extract (0.25 dex for [Ca/Fe]=0 grid) :')
else:
	Teff=_ask_number('Enter the effective temperature of the star you wish to extract (3500 - 10000K in steps of 250K) :')
	#The lowest available log g rises with temperature, so the prompt depends on which main grid Teff falls in.
	#The elif/else chain covers every Teff, so Logg is always defined.
	if float(Teff)<6000.1:
		Logg=_ask_number('Enter the surface gravity of the star you wish to extract (0.0 - 5.0 dex in steps of 0.5 dex) :')
	elif float(Teff)<8000.1:
		Logg=_ask_number('Enter the surface gravity of the star you wish to extract (1.0 - 5.0 dex in steps of 0.5 dex) :')
	else:
		Logg=_ask_number('Enter the surface gravity of the star you wish to extract (2.0 - 5.0 dex in steps of 0.5 dex) :')
	#The remaining ranges are the same for all three main grids
	MH=_ask_number('Enter the metallicity ([M/H]) of the star you wish to extract (-2.5,-2.0,-1.5,-1.0,-0.5,0.0,0.5 dex) :')
	AlphaM=_ask_number('Enter the alpha enhancement ([alpha/M]) of the star you wish to extract (-0.25,0.0,0.25,0.50,0.75 dex) :')
	CM=_ask_number('Enter the carbon enhancement ([C/M]) of the star you wish to extract (-0.25,0.0,0.25 dex) :')

#Work out which grid the requested spectrum belongs to, and create that grid's output directory.
#Grid stays 'Error' unless every parameter lies inside one grid's range. Each bound sits just
#outside the grid edge (e.g. 3499.9 and 6000.1 for 3500-6000K), so the edge values themselves pass the strict comparisons.

def _in_range(value,low,high):
	"""Return True if low < float(value) < high."""
	return low<float(value)<high

Grid='Error'

if CaFe_input=='MIH':
	#Teff is checked first, so the other parameters are only read once a temperature range has matched
	if _in_range(Teff,3499.9,6000.1):
		_candidate,_logg_min='Main',-0.01
	elif _in_range(Teff,6000.1,8000.1):
		_candidate,_logg_min='Intermediate',0.99
	elif _in_range(Teff,8000.1,10000.1):
		_candidate,_logg_min='High',1.99
	else:
		_candidate=None
	#[M/H], [alpha/M] and [C/M] ranges are shared by all three main grids; the lowest log g depends on the grid
	if (_candidate is not None and _in_range(Logg,_logg_min,5.01) and _in_range(MH,-2.501,0.501)
			and _in_range(AlphaM,-0.251,0.751) and _in_range(CM,-0.251,0.251)):
		Grid=_candidate

elif CaFe_input=='CaFe':
	#The [Ca/Fe]=0 grid is only offered with [alpha/M]=0.25 and [C/M]=0.25
	if (_in_range(Teff,3499.9,6000.1) and _in_range(Logg,-0.01,5.01) and _in_range(MH,-2.501,0.501)
			and float(AlphaM)==0.25 and float(CM)==0.25):
		Grid='CaFe'

#Outside of Grid Ranges : Print Error; otherwise make sure the output directory exists
if CaFe_input not in ('MIH','CaFe'):
	print('Error : Unrecognised grid choice - type MIH or CaFe')
elif Grid=='Error':
	print('Error : Parameters are outside the range of the grids - see the README file for the ranges available')
elif not os.path.isdir(os.path.join(cwd,Grid)):
	os.mkdir(os.path.join(cwd,Grid))

#Report the outcome: either echo the requested parameters together with the grid directory
#the spectrum will be written to, or say that no spectrum will be produced.

if Grid=='Error':
	print('')
	print('Error in grid or parameters. No spectrum will be produced. \n')
else:
	print('\n')
	print('Extracting Star Spectrum with Parameters :')
	_summary=(
		('Teff (K) =',Teff),
		('log g (dex) =',Logg),
		('[M/H] (dex) =',MH),
		('[Alpha/M] (dex) =',AlphaM),
		('[C/M] (dex) =',CM),
	)
	for _label,_value in _summary:
		print('%s %s' % (_label,float(_value)))
	print('')
	print('This spectrum will be placed in the %s grid directory. \n' % Grid)


#Spectrum extraction, the same for every grid:
#find the spectrum's row number in the grid order file, read that row from the grid data file,
#then write it out as a FITS spectrum in the grid's directory.

#Grid name -> (grid data file, grid order file). The output directory has the same name as the grid.
_GRID_FILES={
	'Main':('MainGrid.dat','3500_6000K_Grid_Order.txt'),
	'Intermediate':('IntermediateGrid.dat','6000_8000K_Grid_Order.txt'),
	'High':('HighGrid.dat','8000_10000K_Grid_Order.txt'),
	'CaFe':('CaFeGrid.dat','CaFe_Grid_Order.txt'),
}
#Alternative spellings of the grid names that are also accepted
_GRID_ALIASES={'main':'Main','intermediate':'Intermediate','high':'High','cafe':'CaFe','Cafe':'CaFe'}

def _find_spectrum_index(order_file,target):
	"""Return the 0-based row index of the spectrum whose parameters equal target, or None.

	target is the tuple ([M/H], [alpha/M], [C/M], Teff, log g) as floats. Lines 3 onwards of
	order_file are searched. In each line, column 0 is a label whose text after the first 'm'
	is the 1-based spectrum number, and columns 1-5 hold the parameters in target's order.
	If several lines match, the last one wins. Blank or short lines are skipped.
	"""
	with open(order_file) as handle:
		rows=handle.readlines()
	index=None
	for row in rows[2:]:
		fields=row.split()
		#all() stops at the first mismatching column, so later columns are only parsed when needed
		if len(fields)>5 and all(float(field)==wanted for field,wanted in zip(fields[1:6],target)):
			label=fields[0]
			index=int(float(label[label.find('m')+1:])-1)
	return index

def _read_spectrum(data_file,index):
	"""Return row `index` of data_file (counted after its 16 header lines) as a float array, or None if there is no such row."""
	with open(data_file,'r') as handle:
		#skip the headers
		for _ in range(16):
			next(handle)
		for row_number,row in enumerate(handle):
			if row_number==index:
				#The last space-separated token is dropped - presumably the trailing newline; TODO confirm against the grid file format
				return np.array(row.split(' ')[0:-1],dtype=float)
	return None

def _write_spectrum(data,grid_dir):
	"""Write data as a primary-HDU FITS file in cwd/grid_dir, named after the requested parameters (overwriting any existing file)."""
	hdu=fits.PrimaryHDU()
	hdu.data=data
	#Linear wavelength axis: pixel 1 is at lambda1 and each pixel is deltalambda wide
	hdu.header['CRPIX1']=1.0
	hdu.header['CRVAL1']=lambda1
	hdu.header['CDELT1']=deltalambda
	hdu.header['NAXIS1']=Npix
	hdu.header['CD1_1']=deltalambda
	hdu.header['LTM1_1']=1.0
	name='T'+str(Teff)+'g'+str(Logg)+'MH'+str(MH)+'AM'+str(AlphaM)+'CM'+str(CM)+'.fits'
	hdu.writeto(os.path.join(cwd,grid_dir,name),overwrite=True)

_grid_key=_GRID_ALIASES.get(Grid,Grid)
if _grid_key in _GRID_FILES:
	_data_file,_order_file=_GRID_FILES[_grid_key]
	_index=_find_spectrum_index(_order_file,(float(MH),float(AlphaM),float(CM),float(Teff),float(Logg)))
	if _index is None:
		#Previously this case crashed with an undefined spec_line
		print('Error : No spectrum with these parameters is listed in '+_order_file+'. No spectrum will be produced. \n')
	else:
		_data=_read_spectrum(_data_file,_index)
		if _data is None:
			#Previously this case silently wrote nothing
			print('Error : '+_data_file+' has no row for this spectrum. No spectrum will be produced. \n')
		else:
			_write_spectrum(_data,_grid_key)


	
