# Source code for galpy.util.plot

##############################################################################
#
#   plot.py: general wrappers for matplotlib plotting
#
#       'public' methods:
#                         end_print
#                         dens2d
#                         hist
#                         plot
#                         start_print
#                         scatterplot (like hogg_scatterplot)
#                         text
#
#                         this module also defines a custom matplotlib 
#                         projection in which the polar azimuth increases
#                         clockwise (as in, the Galaxy viewed from the NGP)
#                         
#############################################################################
#############################################################################
#Copyright (c) 2010 - 2020, Jo Bovy
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without 
#modification, are permitted provided that the following conditions are met:
#
#   Redistributions of source code must retain the above copyright notice, 
#      this list of conditions and the following disclaimer.
#   Redistributions in binary form must reproduce the above copyright notice, 
#      this list of conditions and the following disclaimer in the 
#      documentation and/or other materials provided with the distribution.
#   The name of the author may not be used to endorse or promote products 
#      derived from this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
#"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
#LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
#A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
#HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
#INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
#BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
#OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
#AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
#LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
#WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#POSSIBILITY OF SUCH DAMAGE.
#############################################################################
import re
from pkg_resources import parse_version
import numpy
from scipy import special
from scipy import interpolate
from scipy import ndimage
import matplotlib.pyplot as pyplot
import matplotlib.ticker as ticker
import matplotlib.cm as cm
import matplotlib
from matplotlib import rc
from matplotlib.ticker import NullFormatter
from matplotlib.projections import PolarAxes, register_projection
from matplotlib.transforms import Affine2D, Bbox, IdentityTransform
from mpl_toolkits.mplot3d import Axes3D # Necessary for 3D plotting (projection = '3d')
_MPL_VERSION= parse_version(matplotlib.__version__)
from ..util.config import __config__
# Optionally apply the seaborn-based plotting defaults when the configuration
# asks for them; seaborn is an optional dependency
if __config__.getboolean('plot','seaborn-bovy-defaults'):
    try:
        import seaborn as sns
    except Exception:
        # Best-effort styling: a missing or broken seaborn must not make
        # importing this module fail, but KeyboardInterrupt and SystemExit
        # should not be swallowed (a bare except would catch those too)
        pass
    else:
        sns.set_style('ticks',
                      {'xtick.direction': u'in',
                       'ytick.direction': u'in',
                       'axes.labelsize': 18.0,
                       'axes.titlesize': 18.0,
                       'figure.figsize': numpy.array([ 6.64,  4.  ]),
                       'grid.linewidth': 2.0,
                       'legend.fontsize': 18.0,
                       'lines.linewidth': 2.0,
                       'lines.markeredgewidth': 0.0,
                       'lines.markersize': 14.0,
                       'patch.linewidth': 0.6,
                       'xtick.labelsize': 16.0,
                       'xtick.major.pad': 14.0,
                       'xtick.major.width': 2.0,
                       'xtick.minor.width': 1.0,
                       'ytick.labelsize': 16.0,
                       'ytick.major.pad': 14.0,
                       'ytick.major.width': 2.0,})
# Default number of contour levels drawn by dens2d when none are given
_DEFAULTNCNTR= 10
def end_print(filename,**kwargs):
    """
    NAME:

       end_print

    PURPOSE:

       write the current figure to filename and close it

    INPUT:

       filename - filename for plot (with extension)

    OPTIONAL INPUTS:

       format - file-format; when not given, the text after the last '.'
                in filename is used

       + any other pyplot.savefig keyword

    OUTPUT:

       (none)

    HISTORY:

       2009-12-23 - Written - Bovy (NYU)

    """
    # Only derive the format from the filename when the caller did not
    # specify one explicitly
    if 'format' not in kwargs:
        kwargs['format']= re.split(r'\.',filename)[-1]
    pyplot.savefig(filename,**kwargs)
    pyplot.close()
def hist(x,xlabel=None,ylabel=None,overplot=False,**kwargs):
    """
    NAME:

       hist

    PURPOSE:

       wrapper around matplotlib's hist function

    INPUT:

       x - array to histogram (or a list, for multiple data sets)

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       overplot - if True, draw on the current figure and leave its ranges,
                  labels, and ticks alone

       xrange - set the x-axis range (also used as the histogram range
                unless range= is given)

       yrange - set the y-axis range

       +all pyplot.hist keywords

    OUTPUT:
       (from the matplotlib docs: http://matplotlib.sourceforge.net/api/pyplot_api.html#matplotlib.pyplot.hist)

       The return value is a tuple (n, bins, patches) or ([n0, n1, ...], bins, [patches0, patches1,...]) if the input contains multiple data

    HISTORY:

       2009-12-23 - Written - Bovy (NYU)

    """
    if not overplot:
        pyplot.figure()
    xrangeSet= 'xrange' in kwargs
    if xrangeSet:
        xlimits= kwargs.pop('xrange')
        # xrange doubles as the histogram range unless one is given
        if 'range' not in kwargs:
            kwargs['range']= xlimits
    yrangeSet= 'yrange' in kwargs
    if yrangeSet:
        ylimits= kwargs.pop('yrange')
    out= pyplot.hist(x,**kwargs)
    if overplot:
        return out
    _add_axislabels(xlabel,ylabel)
    if xrangeSet:
        pyplot.xlim(xlimits)
    elif 'range' in kwargs:
        pyplot.xlim(kwargs['range'])
    else:
        # No range requested: span the data. Lists have no .min()/.max(),
        # so convert them first; previously the list case computed the
        # limits but never applied them to the axes
        data= numpy.array(x) if isinstance(x,list) else x
        pyplot.xlim(data.min(),data.max())
    if yrangeSet:
        pyplot.ylim(ylimits)
    _add_ticks()
    return out
def plot(*args,**kwargs):
    """
    NAME:

       plot

    PURPOSE:

       wrapper around matplotlib's plot function

    INPUT:

       see http://matplotlib.sourceforge.net/api/pyplot_api.html#matplotlib.pyplot.plot

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       xrange

       yrange

       scatter= if True, use pyplot.scatter and its options etc.

       colorbar= if True, and scatter==True, add colorbar

       crange - range for colorbar of scatter==True

       clabel= label for colorbar

       overplot=True does not start a new figure and does not change the ranges and labels

       gcf=True does not start a new figure (does change the ranges and labels)

       onedhists - if True, make one-d histograms on the sides

       onedhistcolor, onedhistfc, onedhistec

       onedhistxnormed, onedhistynormed - normalize the one-d histograms (passed to hist as density=)

       onedhistxweights, onedhistyweights - weights keyword for one-d histograms

       bins= number of bins for onedhists

       semilogx=, semilogy=, loglog= if True, plot logs

    OUTPUT:

       plot to output device, returns what pyplot.plot returns, or 3 Axes instances if onedhists=True

    HISTORY:

       2009-12-28 - Written - Bovy (NYU)

    """
    overplot= kwargs.pop('overplot',False)
    gcf= kwargs.pop('gcf',False)
    onedhists= kwargs.pop('onedhists',False)
    scatter= kwargs.pop('scatter',False)
    loglog= kwargs.pop('loglog',False)
    semilogx= kwargs.pop('semilogx',False)
    semilogy= kwargs.pop('semilogy',False)
    colorbar= kwargs.pop('colorbar',False)
    onedhisttype= kwargs.pop('onedhisttype','step')
    onedhistcolor= kwargs.pop('onedhistcolor','k')
    onedhistfc= kwargs.pop('onedhistfc','w')
    onedhistec= kwargs.pop('onedhistec','k')
    onedhistxnormed= kwargs.pop('onedhistxnormed',True)
    onedhistynormed= kwargs.pop('onedhistynormed',True)
    onedhistxweights= kwargs.pop('onedhistxweights',None)
    onedhistyweights= kwargs.pop('onedhistyweights',None)
    if 'bins' in kwargs:
        bins= kwargs.pop('bins')
    elif onedhists:
        # Default bin count scales with sqrt(N); int() guarantees an integer
        # count regardless of what round() returns for a numpy float
        if isinstance(args[0],numpy.ndarray):
            bins= int(round(0.3*numpy.sqrt(args[0].shape[0])))
        elif isinstance(args[0],list):
            bins= int(round(0.3*numpy.sqrt(len(args[0]))))
        else:
            bins= 30
    if onedhists:
        if overplot or gcf:
            fig= pyplot.gcf()
        else:
            fig= pyplot.figure()
        nullfmt= NullFormatter() # no labels
        # definitions for the axes
        left, width= 0.1, 0.65
        bottom, height= 0.1, 0.65
        bottom_h= left_h= left+width
        rect_scatter= [left, bottom, width, height]
        rect_histx= [left, bottom_h, width, 0.2]
        rect_histy= [left_h, bottom, 0.2, height]
        axScatter= pyplot.axes(rect_scatter)
        axHistx= pyplot.axes(rect_histx)
        axHisty= pyplot.axes(rect_histy)
        # no labels on the histogram panels
        axHistx.xaxis.set_major_formatter(nullfmt)
        axHistx.yaxis.set_major_formatter(nullfmt)
        axHisty.xaxis.set_major_formatter(nullfmt)
        axHisty.yaxis.set_major_formatter(nullfmt)
        fig.sca(axScatter)
    elif not overplot and not gcf:
        pyplot.figure()
    ax= pyplot.gca()
    ax.set_autoscale_on(False)
    xlabel= kwargs.pop('xlabel',None)
    ylabel= kwargs.pop('ylabel',None)
    clabel= kwargs.pop('clabel',None)
    # Ranges default to the span of the data
    xlimits= kwargs.pop('xrange',None)
    if xlimits is None:
        if isinstance(args[0],list):
            xlimits= (numpy.array(args[0]).min(),numpy.array(args[0]).max())
        else:
            xlimits= (args[0].min(),args[0].max())
    ylimits= kwargs.pop('yrange',None)
    if ylimits is None:
        if isinstance(args[1],list):
            ylimits= (numpy.array(args[1]).min(),numpy.array(args[1]).max())
        else:
            ylimits= (args[1].min(),args[1].max())
    climits= kwargs.pop('crange',None)
    if climits is None and scatter and 'c' in kwargs:
        if isinstance(kwargs['c'],list):
            climits= (numpy.array(kwargs['c']).min(),
                      numpy.array(kwargs['c']).max())
        else:
            climits= (kwargs['c'].min(),kwargs['c'].max())
    if scatter:
        out= pyplot.scatter(*args,**kwargs)
    elif loglog:
        out= pyplot.loglog(*args,**kwargs)
    elif semilogx:
        out= pyplot.semilogx(*args,**kwargs)
    elif semilogy:
        out= pyplot.semilogy(*args,**kwargs)
    else:
        out= pyplot.plot(*args,**kwargs)
    if not overplot:
        if semilogy:
            ax= pyplot.gca()
            ax.set_yscale('log')
        elif semilogx:
            ax= pyplot.gca()
            ax.set_xscale('log')
        elif loglog:
            ax= pyplot.gca()
            ax.set_xscale('log')
            ax.set_yscale('log')
        pyplot.xlim(*xlimits)
        pyplot.ylim(*ylimits)
        _add_axislabels(xlabel,ylabel)
        # Minor ticks are only added on the linear axes
        if not semilogy and not semilogx and not loglog:
            _add_ticks()
        elif semilogy:
            _add_ticks(xticks=True,yticks=False)
        elif semilogx:
            _add_ticks(yticks=True,xticks=False)
    #Add colorbar
    if colorbar:
        cbar= pyplot.colorbar(out,fraction=0.15)
        # Without crange= or c= there is no color range to impose; leave the
        # mappable's own limits in place rather than unpacking None
        if climits is not None:
            if _MPL_VERSION < parse_version('3.1'): # pragma: no cover
                # https://matplotlib.org/3.1.0/api/api_changes.html#colorbarbase-inheritance
                cbar.set_clim(*climits)
            else:
                cbar.mappable.set_clim(*climits)
        if clabel is not None:
            cbar.set_label(clabel)
    #Add onedhists
    if not onedhists:
        return out
    # density= replaces the normed= keyword, which newer Matplotlib no
    # longer accepts
    histx, edges, patches= axHistx.hist(args[0],bins=bins,
                                        density=onedhistxnormed,
                                        weights=onedhistxweights,
                                        histtype=onedhisttype,
                                        range=sorted(xlimits),
                                        color=onedhistcolor,fc=onedhistfc,
                                        ec=onedhistec)
    histy, edges, patches= axHisty.hist(args[1],bins=bins,
                                        orientation='horizontal',
                                        weights=onedhistyweights,
                                        density=onedhistynormed,
                                        histtype=onedhisttype,
                                        range=sorted(ylimits),
                                        color=onedhistcolor,fc=onedhistfc,
                                        ec=onedhistec)
    axHistx.set_xlim(axScatter.get_xlim())
    axHisty.set_ylim(axScatter.get_ylim())
    axHistx.set_ylim(0,1.2*numpy.amax(histx))
    axHisty.set_xlim(0,1.2*numpy.amax(histy))
    return (axScatter,axHistx,axHisty)
def plot3d(*args,**kwargs):
    """
    NAME:

       plot3d

    PURPOSE:

       plot in 3d much as in 2d

    INPUT:

       see http://matplotlib.sourceforge.net/api/pyplot_api.html#matplotlib.pyplot.plot

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       zlabel - (raw string!) z-axis label, LaTeX math mode, no $s needed

       xrange

       yrange

       zrange

       overplot=True does not start a new figure and does not change the ranges and labels

    OUTPUT:

       what pyplot.plot returns

    HISTORY:

       2011-01-08 - Written - Bovy (NYU)

    """
    overplot= kwargs.pop('overplot',False)
    if not overplot:
        pyplot.figure()
    ax= pyplot.gca(projection='3d')
    ax.set_autoscale_on(False)
    xlabel= kwargs.pop('xlabel',None)
    ylabel= kwargs.pop('ylabel',None)
    zlabel= kwargs.pop('zlabel',None)
    # Each range defaults to the span of the corresponding positional
    # argument; previously the lower z limit was taken from the y data
    limits= []
    for key, coord in zip(('xrange','yrange','zrange'),args[:3]):
        if key in kwargs:
            limits.append(kwargs.pop(key))
        else:
            coord= numpy.array(coord) if isinstance(coord,list) else coord
            limits.append((coord.min(),coord.max()))
    # Also pop ranges for which no positional argument exists, so that they
    # are never forwarded to pyplot.plot
    for key in ('xrange','yrange','zrange')[len(limits):]:
        limits.append(kwargs.pop(key))
    xlimits, ylimits, zlimits= limits
    out= pyplot.plot(*args,**kwargs)
    if not overplot:
        # Labels are wrapped in $...$ unless they already start with a $
        if xlabel is not None:
            if xlabel[0] != '$':
                xlabel= r'$'+xlabel+'$'
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            if ylabel[0] != '$':
                ylabel= r'$'+ylabel+'$'
            ax.set_ylabel(ylabel)
        if zlabel is not None:
            if zlabel[0] != '$':
                zlabel= r'$'+zlabel+'$'
            ax.set_zlabel(zlabel)
        ax.set_xlim3d(*xlimits)
        ax.set_ylim3d(*ylimits)
        ax.set_zlim3d(*zlimits)
    return out
def dens2d(X,**kwargs):
    """
    NAME:

       dens2d

    PURPOSE:

       plot a 2d density with optional contours

    INPUT:

       first argument is the density (2D array; rows run along y, columns along x)

       matplotlib.pyplot.imshow keywords (see http://matplotlib.sourceforge.net/api/axes_api.html#matplotlib.axes.Axes.imshow)

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       zlabel - (raw string!) colorbar label, LaTeX math mode, no $s needed

       xrange

       yrange

       extent - [xmin,xmax,ymin,ymax]; takes the place of xrange and yrange

       noaxes - don't plot any axes

       overplot - if True, overplot

       colorbar - if True, add colorbar

       shrink= colorbar argument: shrink the colorbar by the factor (optional)

       conditional - normalize each column separately (for probability densities, i.e., cntrmass=True)

       gcf=True does not start a new figure (does change the ranges and labels)

       Contours:

       justcontours - if True, only draw contours

       contours - if True, draw contours (10 by default)

       levels - contour-levels

       cntrmass - if True, the density is a probability and the levels are probability masses contained within the contour

       cntrcolors - colors for contours (single color or array)

       cntrlabel - label the contours

       cntrlw, cntrls - linewidths and linestyles for contour

       cntrlabelsize, cntrlabelcolors,cntrinline - contour arguments

       cntrSmooth - use ndimage.gaussian_filter to smooth before contouring

       onedhists - if True, make one-d histograms on the sides

       onedhistcolor - histogram color

       retAxes= return all Axes instances

       retCont= return the contour instance

       retCumImage= return the image that is contoured (the cumulative-mass image when cntrmass is set)

    OUTPUT:

       plot to output device, Axes instances depending on input

    HISTORY:

       2010-03-09 - Written - Bovy (NYU)

    """
    overplot= kwargs.pop('overplot',False)
    gcf= kwargs.pop('gcf',False)
    if not overplot and not gcf:
        pyplot.figure()
    xlabel= kwargs.pop('xlabel',None)
    ylabel= kwargs.pop('ylabel',None)
    zlabel= kwargs.pop('zlabel',None)
    if 'extent' in kwargs:
        extent= kwargs.pop('extent')
        # The limits are needed below for the default aspect ratio and for
        # the one-d histograms, so recover them from the extent
        xlimits= list(extent[:2])
        ylimits= list(extent[2:4])
    else:
        # list() so that tuples and arrays concatenate rather than fail/add
        xlimits= list(kwargs.pop('xrange',[0,X.shape[1]]))
        ylimits= list(kwargs.pop('yrange',[0,X.shape[0]]))
        extent= xlimits+ylimits
    if 'aspect' not in kwargs:
        kwargs['aspect']= (xlimits[1]-xlimits[0])/float(ylimits[1]-ylimits[0])
    noaxes= kwargs.pop('noaxes',False)
    justcontours= kwargs.pop('justcontours',False)
    # Contours are drawn when requested explicitly or implied by another
    # contour-related keyword
    contours= bool(kwargs.get('contours',False)
                   or 'levels' in kwargs
                   or justcontours
                   or kwargs.get('cntrmass',False))
    kwargs.pop('contours',None)
    if 'levels' in kwargs:
        levels= kwargs.pop('levels')
    elif contours:
        if kwargs.get('cntrmass',False):
            levels= numpy.linspace(0.,1.,_DEFAULTNCNTR)
        elif True in numpy.isnan(numpy.array(X)):
            levels= numpy.linspace(numpy.nanmin(X),numpy.nanmax(X),
                                   _DEFAULTNCNTR)
        else:
            levels= numpy.linspace(numpy.amin(X),numpy.amax(X),_DEFAULTNCNTR)
    cntrmass= kwargs.pop('cntrmass',False)
    conditional= kwargs.pop('conditional',False)
    cntrcolors= kwargs.pop('cntrcolors','k')
    cntrlabel= kwargs.pop('cntrlabel',False)
    cntrlw= kwargs.pop('cntrlw',None)
    cntrls= kwargs.pop('cntrls',None)
    cntrSmooth= kwargs.pop('cntrSmooth',None)
    cntrlabelsize= kwargs.pop('cntrlabelsize',None)
    cntrlabelcolors= kwargs.pop('cntrlabelcolors',None)
    cntrinline= kwargs.pop('cntrinline',None)
    retCumImage= kwargs.pop('retCumImage',False)
    cb= kwargs.pop('colorbar',False)
    shrink= kwargs.pop('shrink',None)
    onedhists= kwargs.pop('onedhists',False)
    onedhistcolor= kwargs.pop('onedhistcolor','k')
    retAxes= kwargs.pop('retAxes',False)
    retCont= kwargs.pop('retCont',False)
    if onedhists:
        # A new figure (if one is wanted) was already opened above; reuse
        # the current figure instead of opening a second, stray one
        fig= pyplot.gcf()
        nullfmt= NullFormatter() # no labels
        # definitions for the axes
        left, width= 0.1, 0.65
        bottom, height= 0.1, 0.65
        bottom_h= left_h= left+width
        rect_scatter= [left, bottom, width, height]
        rect_histx= [left, bottom_h, width, 0.2]
        rect_histy= [left_h, bottom, 0.2, height]
        axScatter= pyplot.axes(rect_scatter)
        axHistx= pyplot.axes(rect_histx)
        axHisty= pyplot.axes(rect_histy)
        # no labels on the histogram panels
        axHistx.xaxis.set_major_formatter(nullfmt)
        axHistx.yaxis.set_major_formatter(nullfmt)
        axHisty.xaxis.set_major_formatter(nullfmt)
        axHisty.yaxis.set_major_formatter(nullfmt)
        fig.sca(axScatter)
    ax= pyplot.gca()
    ax.set_autoscale_on(False)
    if conditional:
        # Normalize each column by its sum; broadcasting handles non-square
        # input, which the previous square tiling could not
        plotthis= X/numpy.sum(X,axis=0)
    else:
        plotthis= X
    if not justcontours:
        out= pyplot.imshow(plotthis,extent=extent,**kwargs)
    if not overplot:
        pyplot.axis(extent)
        _add_axislabels(xlabel,ylabel)
        _add_ticks()
    #Add colorbar
    if cb and not justcontours:
        if shrink is None:
            shrink= numpy.amin([float(kwargs.pop('aspect',1.))*0.87,1.])
        CB1= pyplot.colorbar(out,shrink=shrink)
        if zlabel is not None:
            if zlabel[0] != '$':
                thiszlabel= r'$'+zlabel+'$'
            else:
                thiszlabel= zlabel
            CB1.set_label(thiszlabel)
    if contours or retCumImage:
        aspect= kwargs.get('aspect',None)
        origin= kwargs.get('origin',None)
        if cntrmass:
            #Sum from the top down!
            # Zero the NaNs in a copy so the caller's array is not modified
            plotthis= numpy.where(numpy.isnan(plotthis),0.,plotthis)
            flat= plotthis.flatten()
            sortindx= numpy.argsort(flat)[::-1]
            cumul= numpy.cumsum(numpy.sort(flat)[::-1])/numpy.sum(flat)
            cntrThis= numpy.zeros(numpy.prod(plotthis.shape))
            cntrThis[sortindx]= cumul
            cntrThis= numpy.reshape(cntrThis,plotthis.shape)
        else:
            cntrThis= plotthis
        if contours:
            if cntrSmooth is not None:
                cntrThis= ndimage.gaussian_filter(cntrThis,cntrSmooth,
                                                  mode='nearest')
            cont= pyplot.contour(cntrThis,levels,colors=cntrcolors,
                                 linewidths=cntrlw,extent=extent,aspect=aspect,
                                 linestyles=cntrls,origin=origin)
            if cntrlabel:
                pyplot.clabel(cont,fontsize=cntrlabelsize,
                              colors=cntrlabelcolors,
                              inline=cntrinline)
    if noaxes:
        ax.set_axis_off()
    #Add onedhists
    if not onedhists:
        if retCumImage:
            return cntrThis
        elif retAxes:
            return pyplot.gca()
        elif retCont:
            return cont
        elif justcontours:
            return cntrThis
        else:
            return out
    # Marginals: sum over the other axis times that axis' bin width
    # (rows of X run along y, columns along x); nansum bc nan is *no dens value*
    histx= numpy.nansum(X.T,axis=1)*numpy.fabs(ylimits[1]-ylimits[0])/X.shape[0]
    histy= numpy.nansum(X.T,axis=0)*numpy.fabs(xlimits[1]-xlimits[0])/X.shape[1]
    histx[numpy.isnan(histx)]= 0.
    histy[numpy.isnan(histy)]= 0.
    dx= (extent[1]-extent[0])/float(len(histx))
    axHistx.plot(numpy.linspace(extent[0]+dx,extent[1]-dx,len(histx)),histx,
                 drawstyle='steps-mid',color=onedhistcolor)
    dy= (extent[3]-extent[2])/float(len(histy))
    axHisty.plot(histy,numpy.linspace(extent[2]+dy,extent[3]-dy,len(histy)),
                 drawstyle='steps-mid',color=onedhistcolor)
    axHistx.set_xlim(axScatter.get_xlim())
    axHisty.set_ylim(axScatter.get_ylim())
    axHistx.set_ylim(0,1.2*numpy.amax(histx))
    axHisty.set_xlim(0,1.2*numpy.amax(histy))
    if retCumImage:
        return cntrThis
    elif retAxes:
        return (axScatter,axHistx,axHisty)
    elif justcontours:
        return cntrThis
    else:
        return out
def start_print(fig_width=5,fig_height=5,axes_labelsize=16,
                text_fontsize=11,legend_fontsize=12,
                xtick_labelsize=10,ytick_labelsize=10,
                xtick_minor_size=2,ytick_minor_size=2,
                xtick_major_size=4,ytick_major_size=4):
    """
    NAME:

       start_print

    PURPOSE:

       set matplotlib's rcParams up for making a figure (LaTeX text, inward ticks on all sides, given sizes)

    INPUT:

       fig_width - width in inches

       fig_height - height in inches

       axes_labelsize - size of the axis-labels

       text_fontsize - font-size of the text (if any)

       legend_fontsize - font-size of the legend (if any)

       xtick_labelsize - size of the x-axis labels

       ytick_labelsize - size of the y-axis labels

       xtick_minor_size - size of the minor x-ticks

       ytick_minor_size - size of the minor y-ticks

       xtick_major_size - size of the major x-ticks

       ytick_major_size - size of the major y-ticks

    OUTPUT:

       (none)

    HISTORY:

       2009-12-23 - Written - Bovy (NYU)

    """
    # Font and label sizes
    settings= {'axes.labelsize': axes_labelsize,
               'font.size': text_fontsize,
               'legend.fontsize': legend_fontsize,
               'xtick.labelsize': xtick_labelsize,
               'ytick.labelsize': ytick_labelsize}
    # Figure-wide settings
    settings['text.usetex']= True
    settings['figure.figsize']= [fig_width,fig_height]
    # Tick lengths
    settings['xtick.major.size']= xtick_major_size
    settings['ytick.major.size']= ytick_major_size
    settings['xtick.minor.size']= xtick_minor_size
    settings['ytick.minor.size']= ytick_minor_size
    # Legend markers and tick placement
    settings['legend.numpoints']= 1
    settings['xtick.top']= True
    settings['xtick.direction']= 'in'
    settings['ytick.right']= True
    settings['ytick.direction']= 'in'
    pyplot.rcParams.update(settings)
    latex_preamble= r'\usepackage{amsmath}'+'\n'+r'\usepackage{amssymb}'
    rc('text.latex',preamble=latex_preamble)
def text(*args,**kwargs):
    """
    NAME:

       text

    PURPOSE:

       thin wrapper around matplotlib's text and annotate

       use keywords:
          'bottom_left=True'
          'bottom_right=True'
          'top_left=True'
          'top_right=True'
          'title=True'

       to place the text in one of the corners or use it as the title; when
       several are set, the first in the order title, bottom_left,
       bottom_right, top_right, top_left wins

    INPUT:

       see matplotlib's text
          (http://matplotlib.sourceforge.net/api/pyplot_api.html#matplotlib.pyplot.text)

    OUTPUT:

       prints text on the current figure

    HISTORY:

       2010-01-26 - Written - Bovy (NYU)

    """
    # Pop every placement keyword up front so that none of them (e.g., an
    # explicit top_left=False) is forwarded to matplotlib as an unknown
    # keyword
    title= kwargs.pop('title',False)
    bottom_left= kwargs.pop('bottom_left',False)
    bottom_right= kwargs.pop('bottom_right',False)
    top_right= kwargs.pop('top_right',False)
    top_left= kwargs.pop('top_left',False)
    if title:
        pyplot.annotate(args[0],(0.5,1.05),xycoords='axes fraction',
                        horizontalalignment='center',
                        verticalalignment='top',**kwargs)
    elif bottom_left:
        pyplot.annotate(args[0],(0.05,0.05),xycoords='axes fraction',**kwargs)
    elif bottom_right:
        pyplot.annotate(args[0],(0.95,0.05),xycoords='axes fraction',
                        horizontalalignment='right',**kwargs)
    elif top_right:
        pyplot.annotate(args[0],(0.95,0.95),xycoords='axes fraction',
                        horizontalalignment='right',
                        verticalalignment='top',**kwargs)
    elif top_left:
        pyplot.annotate(args[0],(0.05,0.95),xycoords='axes fraction',
                        verticalalignment='top',**kwargs)
    else:
        pyplot.text(*args,**kwargs)
def scatterplot(x,y,*args,**kwargs):
    """
    NAME:

       scatterplot

    PURPOSE:

       make a 'smart' scatterplot that is a density plot in high-density
       regions and a regular scatterplot for outliers

    INPUT:

       x, y - data (arrays or lists)

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       xrange

       yrange

       bins - number of bins to use in each dimension

       weights - data-weights

       aspect - aspect ratio

       conditional - normalize each column separately (for probability densities, i.e., cntrmass=True)

       gcf=True does not start a new figure (does change the ranges and labels)

       contours - if False, don't plot contours

       justcontours - if True, only draw contours, no density

       cntrcolors - color of contours (can be array as for dens2d)

       cntrlw, cntrls - linewidths and linestyles for contour

       cntrSmooth - use ndimage.gaussian_filter to smooth before contouring

       levels - contour-levels; data points outside of the last level will be individually shown (so, e.g., if this list is descending, contours and data points will be overplotted)

       onedhists - if True, make one-d histograms on the sides

       onedhistx - if True, make one-d histograms on the side of the x distribution

       onedhisty - if True, make one-d histograms on the side of the y distribution

       onedhistcolor, onedhistfc, onedhistec

       onedhistxnormed, onedhistynormed - normalize the one-d histograms (passed to hist as density=)

       onedhistxweights, onedhistyweights - weights keyword for one-d histograms

       cmap= cmap for density plot

       hist= and edges= - you can supply the histogram of the data yourself, this can be useful if you want to censor the data, both need to be set and calculated using scipy.histogramdd with the given range

       retAxes= return all Axes instances

    OUTPUT:

       plot to output device, Axes instance(s) or not, depending on input

    HISTORY:

       2010-04-15 - Written - Bovy (NYU)

    """
    # Work on arrays so that list input supports the boolean masking used to
    # select the outliers below
    x= numpy.asarray(x)
    y= numpy.asarray(y)
    xlabel= kwargs.pop('xlabel',None)
    ylabel= kwargs.pop('ylabel',None)
    # Ranges default to the span of the data; list() so that dens2d can
    # concatenate them whatever sequence type the caller used
    if 'xrange' in kwargs:
        xlimits= list(kwargs.pop('xrange'))
    else:
        xlimits= [x.min(),x.max()]
    if 'yrange' in kwargs:
        ylimits= list(kwargs.pop('yrange'))
    else:
        ylimits= [y.min(),y.max()]
    ndata= len(x)
    # Default bin count scales with sqrt(N); int() guarantees an integer
    # count regardless of what round() returns for a numpy float
    defaultbins= int(round(0.3*numpy.sqrt(ndata)))
    bins= kwargs.pop('bins',defaultbins)
    weights= kwargs.pop('weights',None)
    levels= kwargs.pop('levels',special.erf(numpy.arange(1,4)/numpy.sqrt(2.)))
    aspect= kwargs.pop('aspect',
                       (xlimits[1]-xlimits[0])/(ylimits[1]-ylimits[0]))
    conditional= kwargs.pop('conditional',False)
    contours= kwargs.pop('contours',True)
    justcontours= kwargs.pop('justcontours',False)
    cntrcolors= kwargs.pop('cntrcolors','k')
    cntrlw= kwargs.pop('cntrlw',None)
    cntrls= kwargs.pop('cntrls',None)
    cntrSmooth= kwargs.pop('cntrSmooth',None)
    onedhists= kwargs.pop('onedhists',False)
    onedhistx= kwargs.pop('onedhistx',onedhists)
    onedhisty= kwargs.pop('onedhisty',onedhists)
    onedhisttype= kwargs.pop('onedhisttype','step')
    onedhistcolor= kwargs.pop('onedhistcolor','k')
    onedhistfc= kwargs.pop('onedhistfc','w')
    onedhistec= kwargs.pop('onedhistec','k')
    onedhistls= kwargs.pop('onedhistls','solid')
    onedhistlw= kwargs.pop('onedhistlw',None)
    onedhistsbins= kwargs.pop('onedhistsbins',defaultbins)
    overplot= kwargs.pop('overplot',False)
    gcf= kwargs.pop('gcf',False)
    cmap= kwargs.pop('cmap',cm.gist_yarg)
    onedhistxnormed= kwargs.pop('onedhistxnormed',True)
    onedhistynormed= kwargs.pop('onedhistynormed',True)
    onedhistxweights= kwargs.pop('onedhistxweights',weights)
    onedhistyweights= kwargs.pop('onedhistyweights',weights)
    retAxes= kwargs.pop('retAxes',False)
    anyonedhist= onedhists or onedhistx or onedhisty
    if anyonedhist:
        if overplot or gcf:
            fig= pyplot.gcf()
        else:
            fig= pyplot.figure()
        nullfmt= NullFormatter() # no labels
        # definitions for the axes
        left, width= 0.1, 0.65
        bottom, height= 0.1, 0.65
        bottom_h= left_h= left+width
        rect_scatter= [left, bottom, width, height]
        rect_histx= [left, bottom_h, width, 0.2]
        rect_histy= [left_h, bottom, 0.2, height]
        axScatter= pyplot.axes(rect_scatter)
        if onedhistx:
            axHistx= pyplot.axes(rect_histx)
            # no labels
            axHistx.xaxis.set_major_formatter(nullfmt)
            axHistx.yaxis.set_major_formatter(nullfmt)
        if onedhisty:
            axHisty= pyplot.axes(rect_histy)
            # no labels
            axHisty.xaxis.set_major_formatter(nullfmt)
            axHisty.yaxis.set_major_formatter(nullfmt)
        fig.sca(axScatter)
    data= numpy.array([x,y]).T
    if 'hist' in kwargs and 'edges' in kwargs:
        dens= kwargs.pop('hist')
        edges= kwargs.pop('edges')
    else:
        dens, edges= numpy.histogramdd(data,bins=bins,
                                       range=[xlimits,ylimits],
                                       weights=weights)
    # dens2d must draw on the axes/figure set up here whenever one exists
    dens2d_overplot= (gcf or onedhists or overplot or onedhistx or onedhisty)
    if contours:
        cumimage= dens2d(dens.T,contours=contours,levels=levels,
                         cntrmass=contours,cntrSmooth=cntrSmooth,
                         cntrcolors=cntrcolors,cmap=cmap,origin='lower',
                         xrange=xlimits,yrange=ylimits,xlabel=xlabel,
                         ylabel=ylabel,interpolation='nearest',
                         retCumImage=True,aspect=aspect,
                         conditional=conditional,
                         cntrlw=cntrlw,cntrls=cntrls,
                         justcontours=justcontours,zorder=5*justcontours,
                         overplot=dens2d_overplot)
    else:
        cumimage= dens2d(dens.T,contours=contours,
                         cntrcolors=cntrcolors,
                         cmap=cmap,origin='lower',
                         xrange=xlimits,yrange=ylimits,xlabel=xlabel,
                         ylabel=ylabel,interpolation='nearest',
                         conditional=conditional,
                         retCumImage=True,aspect=aspect,
                         cntrlw=cntrlw,cntrls=cntrls,
                         overplot=dens2d_overplot)
    #Set axes and labels
    pyplot.axis(list(xlimits)+list(ylimits))
    if not overplot:
        _add_axislabels(xlabel,ylabel)
        _add_ticks()
    # Bin centers along each dimension
    xedge= numpy.asarray(edges[0])
    yedge= numpy.asarray(edges[1])
    binxs= (xedge[:-1]+xedge[1:])/2.
    binys= (yedge[:-1]+yedge[1:])/2.
    # Interpolate the image returned by dens2d at every data point; points
    # whose value exceeds the last level are drawn individually
    cumInterp= interpolate.RectBivariateSpline(binxs,binys,cumimage.T,
                                               kx=1,ky=1)
    cums= numpy.atleast_1d(cumInterp.ev(x,y))
    outliers= cums > levels[-1]
    plotx= x[outliers]
    ploty= y[outliers]
    if len(plotx) != 0:
        # 'is not None' rather than '== None': the latter compares
        # element-wise for array weights and has no single truth value
        if weights is not None:
            w8= numpy.asarray(weights)[outliers]
            # Each weighted outlier gets its own gray level
            for ii in range(len(plotx)):
                plot(plotx[ii],ploty[ii],overplot=True,
                     color='%.2f'%(1.-w8[ii]),*args,**kwargs)
        else:
            plot(plotx,ploty,overplot=True,zorder=1,*args,**kwargs)
    #Add onedhists
    if not anyonedhist:
        if retAxes:
            return pyplot.gca()
        else:
            return None
    # density= replaces the normed= keyword, which newer Matplotlib no
    # longer accepts
    if onedhistx:
        histx, edges, patches= axHistx.hist(x,bins=onedhistsbins,
                                            density=onedhistxnormed,
                                            weights=onedhistxweights,
                                            histtype=onedhisttype,
                                            range=sorted(xlimits),
                                            color=onedhistcolor,fc=onedhistfc,
                                            ec=onedhistec,ls=onedhistls,
                                            lw=onedhistlw)
    if onedhisty:
        histy, edges, patches= axHisty.hist(y,bins=onedhistsbins,
                                            orientation='horizontal',
                                            weights=onedhistyweights,
                                            density=onedhistynormed,
                                            histtype=onedhisttype,
                                            range=sorted(ylimits),
                                            color=onedhistcolor,fc=onedhistfc,
                                            ec=onedhistec,ls=onedhistls,
                                            lw=onedhistlw)
    if onedhistx and not overplot:
        axHistx.set_xlim(axScatter.get_xlim())
        axHistx.set_ylim(0,1.2*numpy.amax(histx))
    if onedhisty and not overplot:
        axHisty.set_ylim(axScatter.get_ylim())
        axHisty.set_xlim(0,1.2*numpy.amax(histy))
    if not onedhistx:
        axHistx= None
    if not onedhisty:
        axHisty= None
    if retAxes:
        return (axScatter,axHistx,axHisty)
    else:
        return None
def _add_axislabels(xlabel,ylabel):
    """
    NAME:

       _add_axislabels

    PURPOSE:

       add axis labels to the current figure

    INPUT:

       xlabel - (raw string!) x-axis label, LaTeX math mode, no $s needed

       ylabel - (raw string!) y-axis label, LaTeX math mode, no $s needed

       a label that is None or empty is skipped

    OUTPUT:

       (none; works on the current axes)

    HISTORY:

       2009-12-23 - Written - Bovy (NYU)

    """
    # Labels are wrapped in $...$ unless they already start with a $; empty
    # labels are skipped (indexing their first character would fail)
    if xlabel is not None and len(xlabel) > 0:
        if xlabel[0] != '$':
            thisxlabel= r'$'+xlabel+'$'
        else:
            thisxlabel= xlabel
        pyplot.xlabel(thisxlabel)
    if ylabel is not None and len(ylabel) > 0:
        if ylabel[0] != '$':
            thisylabel= r'$'+ylabel+'$'
        else:
            thisylabel= ylabel
        pyplot.ylabel(thisylabel)


def _add_ticks(xticks=True,yticks=True):
    """
    NAME:

       _add_ticks

    PURPOSE:

       add minor axis ticks to a plot, five per major-tick interval

    INPUT:

       xticks - if True, add minor ticks to the x axis

       yticks - if True, add minor ticks to the y axis

    OUTPUT:

       (none; works on the current axes)

    HISTORY:

       2009-12-23 - Written - Bovy (NYU)

    """
    ax= pyplot.gca()
    if xticks:
        xlocs= ax.xaxis.get_majorticklocs()
        # The spacing is only defined with at least two major ticks
        if len(xlocs) > 1:
            xstep= xlocs[1]-xlocs[0]
            ax.xaxis.set_minor_locator(ticker.MultipleLocator(xstep/5.))
    if yticks:
        ylocs= ax.yaxis.get_majorticklocs()
        if len(ylocs) > 1:
            ystep= ylocs[1]-ylocs[0]
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(ystep/5.))


class GalPolarAxes(PolarAxes):
    '''
    A variant of PolarAxes where theta increases clockwise
    '''
    name = 'galpolar'

    class GalPolarTransform(PolarAxes.PolarTransform):
        """Map (theta, r) to (x, y) = (r cos(theta), -r sin(theta))"""
        def transform(self, tr):
            # numpy.float64 is what the numpy.float_ alias referred to; the
            # alias is no longer available in NumPy 2
            xy = numpy.zeros(tr.shape, numpy.float64)
            t = tr[:, 0:1]
            r = tr[:, 1:2]
            # x and y are views into xy, filled in place
            x = xy[:, 0:1]
            y = xy[:, 1:2]
            x[:] = r * numpy.cos(t)
            # The minus sign is what makes theta increase clockwise
            y[:] = -r * numpy.sin(t)
            return xy

        transform_non_affine = transform

        def inverted(self):
            return GalPolarAxes.InvertedGalPolarTransform()

    class InvertedGalPolarTransform(PolarAxes.InvertedPolarTransform):
        """Map (x, y) back to (theta, r), undoing GalPolarTransform"""
        def transform(self, xy):
            x = xy[:, 0:1]
            y = xy[:, 1:]
            r = numpy.sqrt(x*x + y*y)
            # The forward transform sets y = -r sin(theta), so the sign of y
            # must be flipped again to recover theta
            theta = numpy.arctan2(-y, x)
            return numpy.concatenate((theta, r), 1)

        def inverted(self):
            return GalPolarAxes.GalPolarTransform()

    def _set_lim_and_transforms(self):
        # Start from the standard polar setup, then swap in the clockwise
        # projection and rebuild the transforms that depend on it
        PolarAxes._set_lim_and_transforms(self)
        self.transProjection = self.GalPolarTransform()
        self.transData = (
            self.transScale +
            self.transProjection +
            (self.transProjectionAffine + self.transAxes))
        self._xaxis_transform = (
            self.transProjection +
            self.PolarAffine(IdentityTransform(), Bbox.unit()) +
            self.transAxes)
        self._xaxis_text1_transform = (
            self._theta_label1_position +
            self._xaxis_transform)
        self._yaxis_transform = (
            Affine2D().scale(numpy.pi * 2.0, 1.0) +
            self.transData)
        self._yaxis_text1_transform = (
            self._r_label1_position +
            Affine2D().scale(1.0 / 360.0, 1.0) +
            self._yaxis_transform)


# Make the projection available as projection='galpolar'
register_projection(GalPolarAxes)