"""
@brief 4D contour plot of X, Y meshes (n2, n1), a Z-vector (n4), and a W-matrix (n4, n2, n1), with the w-value being "n3"
@author A. Connors
"""
# @file run_pylab_plot_4example.py
#

from numpy import *
from matplotlib import *
from matplotlib.pylab import *
import pyfits
import matplotlib.axes3d as axes3d
import pylab_plot4d as axes4d


###---------------------------------------------------------###
class Get4DExampleData:
    """Load a 3-D FITS cube and prepare it for a stacked ("4D") contour plot.

    The primary HDU of the FITS file holds an array of shape
    (naxis3, naxis2, naxis1), read here as naxis3 stacked 2-D images
    W[iz][iy][ix].  __init__ gives the images X/Y coordinate grids, trims the
    padded border off every image, and assigns each image a Z height so the
    stack can be drawn with vertical separation.  PlotStacked4DContours then
    draws the stack.

    Attributes set by __init__:
        X, Y    -- 2-D coordinate grids of the trimmed images, shape (n2, n1).
        Z       -- list of n3 Z heights, one per image, symmetric about 0.
        W       -- float array of the trimmed images, shape (n3, n2, n1).
        zdel    -- vertical spacing between successive Z heights.
        Xlabel, Ylabel, ZWlabel, Zlabels, title -- plot annotations.
    """

    def __init__(self, InFilePathName,
                 inXlabel='X', inYlabel='Y', inZWlabel='W', inZlabels=None,
                 inTitle='Results', inXRanged=None, inYRanged=None,
                 inXYPadd=None):
        """Read the cube from InFilePathName and build X, Y, Z and W.

        Args:
            InFilePathName: path of a FITS file whose primary HDU holds a
                3-D array.
            inXlabel, inYlabel, inZWlabel: axis labels used when plotting.
            inZlabels: text labels, one per image; images without a label
                are left unlabelled when plotting.  Defaults to no labels.
            inTitle: plot title.
            inXRanged, inYRanged: [min, max] extent of the full (padded)
                image along X and Y.  When omitted (or shorter than two
                entries) that axis falls back to pixel units, 0 .. naxis.
            inXYPadd: [xpad, ypad] total padding, in the same units as the
                ranges; half of it is trimmed from each edge.
                Defaults to [0., 0.].

        Raises:
            ValueError: if the data array is not 3-D, or if the padding
                would trim away the whole image.
        """
        # None defaults avoid sharing one mutable list between all calls.
        if inZlabels is None:
            inZlabels = []
        if inXRanged is None:
            inXRanged = []
        if inYRanged is None:
            inYRanged = []
        if inXYPadd is None:
            inXYPadd = [0., 0.]

        # For debug, print inputs:
        print('\n Inside Get4DExampleData: Inputs:')
        print('%s %s %s %s %s %s %s %s' % (inXlabel, inYlabel, inZWlabel,
                                           inZlabels, inTitle, inXRanged,
                                           inYRanged, inXYPadd))

        WorkingHDULst = pyfits.open(InFilePathName)
        try:
            print(WorkingHDULst[0].header)
            # Copy the pixels into an in-memory float array *before* closing;
            # the HDU data may be memory-mapped and unsafe to read afterwards.
            data = array(WorkingHDULst[0].data, dtype=float)
        finally:
            WorkingHDULst.close()

        print(data.shape)
        if data.ndim != 3:
            raise ValueError('Expected a 3-D data array, got shape %s'
                             % (data.shape,))
        print(data[0])
        (naxis3, naxis2, naxis1) = data.shape

        # The padding (e.g. used to reach power-of-2 bin counts on a
        # 360 x 180 degree skymap) is trimmed off below.
        xpaddeg, ypaddeg = inXYPadd
        print('Initializing xpaddeg, ypaddeg: %s  ,  %s' % (xpaddeg, ypaddeg))

        # Coordinates of the full (still padded) grid.  Header keywords such
        # as CRVALn / CDELTn are not consulted; any such conversion is left
        # to the caller via inXRanged / inYRanged.
        xmin, xmax, xdelt, x = self._axis_coords(inXRanged, naxis1)
        ymin, ymax, ydelt, y = self._axis_coords(inYRanged, naxis2)
        print(' Xbins, Delt: %s %s ; Ybins, Delt: %s %s \n'
              % (naxis1, xdelt, naxis2, ydelt))
        print('Xmin, Xmax; Ymin, Ymax: %s %s  ;  %s %s \n \n'
              % (xmin, xmax, ymin, ymax))

        # Whole bins of padding to remove from each edge; abs() keeps the
        # count non-negative for reversed axes (max < min).
        ixchop = int(xpaddeg / 2. / abs(xdelt))
        iychop = int(ypaddeg / 2. / abs(ydelt))
        print(' Now chopping of: ixchop = %s  , iychop = %s \n'
              % (ixchop, iychop))

        W, X, Y = self._trim_padding(data, x, y, ixchop, iychop)
        print('W shape and value:  %s \n %s' % (W.shape, W))
        print('X shape and value: %s \n %s' % (X.shape, X))
        print(' Y shape and value: %s \n %s \n \n' % (Y.shape, Y))

        Z, zdel, wmin, wmax = self._stack_levels(W)
        print(' wmin, wmax: %s  ;  %s \n \n' % (wmin, wmax))
        print('Zdel, Zrange:  %s  ;  %s  \n' % (zdel, (len(Z) - 1) * zdel))

        self.X, self.Y, self.Z, self.W = X, Y, Z, W
        self.Xlabel, self.Ylabel, self.ZWlabel = inXlabel, inYlabel, inZWlabel
        self.Zlabels, self.title = inZlabels, inTitle
        self.zdel = zdel
        print('Initializing Get4DExampleData instance Done. \n \n')

    @staticmethod
    def _axis_coords(ranged, npix):
        """Return (vmin, vmax, delt, coords) for one axis of npix bins.

        ranged is [vmin, vmax].  When it has fewer than two entries, pixel
        units (0 .. npix) are used.  coords[i] = vmin + i * delt, so exactly
        npix coordinates are produced (a float-step arange() may yield one
        too many or too few).
        """
        if ranged is not None and len(ranged) >= 2:
            vmin, vmax = float(ranged[0]), float(ranged[1])
        else:
            vmin, vmax = 0., float(npix)
        delt = (vmax - vmin) / float(npix)
        return vmin, vmax, delt, vmin + delt * arange(npix)

    @staticmethod
    def _trim_padding(data, x, y, ixchop, iychop):
        """Trim ixchop columns / iychop rows from each edge of every image.

        data has shape (n3, len(y), len(x)).  Returns (W, X, Y) where W is a
        float copy of the trimmed cube, and X, Y are the matching trimmed
        meshgrid(x, y) coordinate grids, so W[iz][j][i] sits at (X[j][i], Y[j][i]).

        Raises:
            ValueError: if the trim leaves no rows or no columns.
        """
        naxis3, naxis2, naxis1 = data.shape
        if naxis2 - 2 * iychop < 1 or naxis1 - 2 * ixchop < 1:
            raise ValueError('Padding trims away the whole %d x %d image '
                             '(ixchop=%d, iychop=%d)'
                             % (naxis2, naxis1, ixchop, iychop))
        rows = slice(iychop, naxis2 - iychop)
        cols = slice(ixchop, naxis1 - ixchop)
        XBig, YBig = meshgrid(x, y)  # both shaped (naxis2, naxis1)
        W = array(data[:, rows, cols], dtype=float)
        return W, XBig[rows, cols].copy(), YBig[rows, cols].copy()

    @staticmethod
    def _stack_levels(W, minzdel=2., spacing=2.5):
        """Return (levels, zdel, wmin, wmax) used to stack the images of W.

        wmin[iz] / wmax[iz] are the extremes of image W[iz].  zdel is the
        larger of minzdel and spacing * (largest per-image max - min).  levels
        holds len(W) heights, zdel apart, symmetric about 0.
        """
        n3 = len(W)
        wmin = [float(W[iz].min()) for iz in range(n3)]
        wmax = [float(W[iz].max()) for iz in range(n3)]
        # A single list argument keeps max() correct even if the module's
        # star imports shadow the builtin with numpy's amax.
        zdel = max([float(minzdel)] +
                   [spacing * (hi - lo) for lo, hi in zip(wmin, wmax)])
        zbottom = -0.5 * (n3 - 1) * zdel
        levels = [zbottom + iz * zdel for iz in range(n3)]
        return levels, zdel, wmin, wmax

    ####----------------------------------------###################
    #### Now a "stacked contour" plot method for this XYZW data(+labels) object:

    def PlotStacked4DContours(self, numcontours=32,
                              PlotFileName='StackContour4D', PlotType='PNG',
                              text_params=None):
        """Draw the images of self.W as contour plots stacked at heights self.Z.

        Args:
            numcontours: contour count passed through to Stack3DContour4D.
            PlotFileName: output file name without extension.
            PlotType: 'PNG' or 'PS' saves PlotFileName + '.png' / '.ps';
                anything else calls show() instead.
            text_params: text properties for labels and the title (also
                passed as keyword arguments to Stack3DContour4D).
                Defaults to {'fontsize': 'xx-large'}.

        self.W is left unmodified; the marker bar is drawn on a copy.
        """
        if text_params is None:
            text_params = {'fontsize': 'xx-large'}

        ## Adapted from simple3d.py:
        fig = gcf()
        ax4d = axes4d.Axes4D(fig)
        fig.axes.append(ax4d)

        # Work on a copy so the marker bar does not leak into self.W (repeated
        # calls would otherwise keep altering the stored data).
        W = self.W.copy()
        wmin2, wmax2 = W.min(), W.max()
        (n3, n2, n1) = W.shape
        # A bar of height (wmax2 - wmin2) in column 0 of the first (up to)
        # four rows of every image takes the place of a colorbar() call.
        W[:, :4, 0] = wmax2 - wmin2

        ax4d.Stack3DContour4D(self.X, self.Y, self.Z, W, numcontours,
                              cmap=cm.jet, **text_params)

        # For now these are hard-coded; could be better:
        ax4d.set_xlabel(self.Xlabel, text_params)
        ax4d.set_ylabel(self.Ylabel, text_params)
        ax4d.set_zlabel(self.ZWlabel, text_params)

        # Per-image labels sit at 75% of the far corner's X/Y coordinates,
        # slightly above each image's height.  zip() stops at the shorter
        # sequence, so images without a label are simply left unlabelled.
        xpoint = self.X[n2 - 1][n1 - 1] * 0.75
        ypoint = self.Y[n2 - 1][n1 - 1] * 0.75
        zup = 0.12 * self.zdel
        for zlevel, zlabel in zip(self.Z, self.Zlabels):
            ax4d.text3D(xpoint, ypoint, zlevel + zup, zlabel, text_params)

        # Title above the third image (or the top one when there are fewer).
        ztitle = self.Z[min([2, n3 - 1])] + 0.3 * self.zdel
        ax4d.text3D(self.X[n2 - 1][0], self.Y[n2 - 1][0], ztitle,
                    self.title, text_params)

        if PlotType == 'PNG':
            savefig(PlotFileName + '.png')
        elif PlotType == 'PS':
            savefig(PlotFileName + '.ps')
        else:
            show()
