spline.py

Go to the documentation of this file.
"""Cubic spline handling, in a manner compatible with the API in Numerical Recipes.

spline() builds the table of second derivatives for a set of nodes and
splint() evaluates the spline (optionally with its first two derivatives)
from that table.  spline_extension() / spline_extrapolate() add end nodes so
that splint() can extrapolate quadratically beyond the data,
cubeinterpolate() performs exact 4-point cubic interpolation, and
approximate_least_squares_spline() approximates a least-squares spline to a
dataset.  splint() raises RangeError for abscissae outside the table.
"""
_rcsid="$Id: spline.py,v 1.23 2005/07/13 14:24:58 mendenhall Exp $"

__all__ = ["spline", "splint", "cubeinterpolate", "RangeError",
           "spline_extension", "spline_extrapolate", "approximate_least_squares_spline"]
00006 
class RangeError(IndexError):
    """Raised by splint() when an abscissa lies outside the tabulated range."""
    pass
00009 
00010 from Numeric import zeros, Float, searchsorted, array, asarray, take, clip
00011 import Numeric
00012 
def spline(x, y, yp1=None, ypn=None):
    """y2 = spline(x_vals, y_vals, yp1=None, ypn=None)

    Build the table of second derivatives needed by splint(), using the
    tridiagonal sweep of Numerical Recipes.  yp1 / ypn give the first
    derivative at the first / last node; None at either end selects the
    natural condition (zero second derivative there).
    """
    npts = len(x)
    rhs = zeros(npts, Float)
    second = zeros(npts, Float)

    xs = asarray(x, Float)
    ys = asarray(y, Float)

    step = xs[1:] - xs[:-1]
    inv_step = 1.0 / step
    inv_span = 1.0 / (xs[2:] - xs[:-2])
    rise = ys[1:] - ys[:-1]
    ratio = step[:-1] * inv_span
    slope = rise * inv_step

    # Differences of neighbouring slopes; the recursive part of the NR
    # right-hand side is folded in during the forward sweep below.
    rhs[1:-1] = slope[1:] - slope[:-1]

    # left boundary condition
    if yp1 is not None:
        second[0] = -0.5
        rhs[0] = (3.0 * inv_step[0]) * (rise[0] * inv_step[0] - yp1)
    else:
        second[0] = rhs[0] = 0.0

    # forward elimination
    for j in range(1, npts - 1):
        s = ratio[j - 1]
        denom = s * second[j - 1] + 2.0
        second[j] = (s - 1.0) / denom
        rhs[j] = (6.0 * rhs[j] * inv_span[j - 1] - s * rhs[j - 1]) / denom

    # right boundary condition
    if ypn is not None:
        qn = 0.5
        un = (3.0 * inv_step[-1]) * (ypn - rise[-1] * inv_step[-1])
    else:
        qn = un = 0.0

    second[-1] = (un - qn * rhs[-2]) / (qn * second[-2] + 1.0)

    # back substitution, last interior node down to the first node
    j = npts - 2
    while j >= 0:
        second[j] = second[j] * second[j + 1] + rhs[j]
        j -= 1

    return second
00057 
def spline_extension(x, y, y2, xmin=None, xmax=None):
    """x, y, y2 = spline_extension(x_vals, y_vals, y2_vals, xmin=None, xmax=None)

    Return the x, y, y2 tables of an already-created spline, extended so that
    splint() extrapolates quadratically outside the range x[0]-x[-1]: from
    xmin (or x[0] if xmin is None) to xmax (or x[-1] if xmax is None).

    Each added node carries the second derivative of the adjacent end node,
    so the new segment is the quadratic matching the spline's value, slope
    and second derivative at that end node.
    """
    xl = [x]
    yl = [y]
    y2l = [y2]

    if xmin is not None:
        h0 = float(x[1] - x[0])  # float: avoid integer division on int tables
        h1 = xmin - x[0]
        # spline slope at x[0] (A=1, B=0 in the NR derivative formula)
        slope = (y[1] - y[0]) / h0 - h0 * (2.0 * y2[0] + y2[1]) / 6.0
        yextrap = y[0] + slope * h1 + y2[0] * h1 * h1 / 2.0
        yl.insert(0, (yextrap,))
        xl.insert(0, (xmin,))
        y2l.insert(0, (y2[0],))

    if xmax is not None:
        h0 = float(x[-1] - x[-2])
        h1 = xmax - x[-1]
        # spline slope at x[-1] (A=0, B=1 in the NR derivative formula)
        slope = (y[-1] - y[-2]) / h0 + h0 * (y2[-2] + 2.0 * y2[-1]) / 6.0
        yextrap = y[-1] + slope * h1 + y2[-1] * h1 * h1 / 2.0
        yl.append((yextrap,))
        xl.append((xmax,))
        y2l.append((y2[-1],))

    return Numeric.concatenate(xl), Numeric.concatenate(yl), Numeric.concatenate(y2l)
00085 
def spline_extrapolate(x, y, yp1=None, ypn=None, xmin=None, xmax=None):
    """x, y, y2 = spline_extrapolate(x_vals, y_vals, yp1=None, ypn=None, xmin=None, xmax=None)

    Build a spline through (x, y) with spline() (yp1 / ypn are the end
    slopes, None meaning a natural end), then extend its tables with
    spline_extension() so that splint() extrapolates quadratically from xmin
    (or x[0] if xmin is None) to xmax (or x[-1] if xmax is None).
    Returns the extended x, y, y2 tables.
    """
    second_derivs = spline(x, y, yp1, ypn)
    return spline_extension(x, y, second_derivs, xmin, xmax)
00092 
00093 import types
00094 
def splint(xa, ya, y2a, x, derivs=False):
    """Return the value of the spline (xa, ya, y2a) at x.

    xa, ya are the tabulated nodes (xa ascending) and y2a is the table of
    second derivatives produced by spline().  x can either be a scalar or a
    listable item, in which case a Numeric Float array is returned and the
    multiple interpolations are done somewhat more efficiently.

    Raises RangeError if any x lies outside [xa[0], xa[-1]].
    If derivs is not False, return y, y', y'' instead of just y.
    """
    # Anything without a length (Python ints/floats/longs, array scalars) is a
    # single abscissa; checking exact types misrouted numpy/Numeric scalars.
    try:
        len(x)
    except TypeError:
        scalar = True
    else:
        scalar = False

    if scalar:
        if x < xa[0] or x > xa[-1]:
            raise RangeError("%f not in range (%f, %f) in splint()" % (x, xa[0], xa[-1]))

        khi = max(searchsorted(xa, x), 1)  # x == xa[0] still uses the first interval
        klo = khi - 1
        h = float(xa[khi] - xa[klo])
        a = (xa[khi] - x) / h
        b = 1.0 - a
        ylo = ya[klo]
        yhi = ya[khi]
        y2lo = y2a[klo]
        y2hi = y2a[khi]
    else:
        # processing a list: do all the interpolations at once
        x = asarray(x, Float)
        # an empty request skips the check (min/max of nothing would fail)
        if len(x) and (min(x) < xa[0] or max(x) > xa[-1]):
            raise RangeError("(%f, %f) not in range (%f, %f) in splint()" % (min(x), max(x), xa[0], xa[-1]))

        # interval index so that xa[khi-1] <= x <= xa[khi], khi in 1..len(xa)-1
        khi = clip(searchsorted(xa, x), 1, len(xa) - 1)
        klo = khi - 1
        xhi = take(xa, khi)
        xlo = take(xa, klo)
        yhi = take(ya, khi)
        ylo = take(ya, klo)
        y2hi = take(y2a, khi)
        y2lo = take(y2a, klo)

        h = (xhi - xlo).astype(Float)
        a = (xhi - x) / h
        b = 1.0 - a

    y = a * ylo + b * yhi + ((a * a * a - a) * y2lo + (b * b * b - b) * y2hi) * (h * h) / 6.0
    if derivs:
        yp = (yhi - ylo) / h + ((3 * b * b - 1) * y2hi - (3 * a * a - 1) * y2lo) * h / 6.0
        ypp = b * y2hi + a * y2lo
        return y, yp, ypp
    return y
00134 
00135         
def cubeinterpolate(xlist, ylist, x3):
    """Find the point at x3 given 4 points in the given lists, using exact cubic
    interpolation (not splining).

    xlist, ylist: the abscissae and ordinates of exactly four points (each is
        unpacked into four names, so any other length raises ValueError).
    x3: the abscissa at which to evaluate.
    Returns the value at x3 of the cubic through the four points.  The four
    abscissae must be distinct: coincident ones make the denominator below
    zero, raising ZeroDivisionError.
    """
    x1,x2,x4,x5=xlist
    # Shift the origin to the first point: from here on x2, x3, x4, x5 and
    # y2, y4, y5 are float offsets relative to (x1, y1).
    x2,x3,x4,x5=float(x2-x1),float(x3-x1),float(x4-x1),float(x5-x1)
    y1,y2,y4,y5=ylist
    y2,y4, y5=float(y2-y1),float(y4-y1),float(y5-y1)
    
    # Closed-form value of the interpolating cubic at x3 in the shifted frame;
    # in that frame the cubic passes through the origin, hence the overall
    # factor x3.  Adding y1 returns to the original frame.
    y3=(
            (x3*(x2**2*x5**2*(-x2 + x5)*y4 + x4**3*(x5**2*y2 - x2**2*y5) + x4**2*(-(x5**3*y2) + x2**3*y5) + 
                   x3**2*(x2*x5*(-x2 + x5)*y4 + x4**2*(x5*y2 - x2*y5) + x4*(-(x5**2*y2) + x2**2*y5)) + 
                   x3*(x2*x5*(x2**2 - x5**2)*y4 + x4**3*(-(x5*y2) + x2*y5) + x4*(x5**3*y2 - x2**3*y5))))/
                 (x2*(x2 - x4)*x4*(x2 - x5)*(x4 - x5)*x5)
    )+y1
    return y3
00150 
00151 from analysis import fitting_toolkit
00152 
def _is_empty(seq):
    """Return True when seq is None or has no elements.

    Unlike a bare truth test, this is well defined for Numeric/numpy arrays as
    well as for Python lists and tuples.
    """
    return seq is None or len(seq) == 0


def approximate_least_squares_spline(xvals, yvals, nodelist=None, nodeindices=None, nodecount=None):
    """nodelist, ya, y2a = approximate_least_squares_spline(xvals, yvals, nodelist=None, nodeindices=None, nodecount=None)

    Compute an approximation to the true least-squares spline to the dataset
    (xvals, yvals).  xvals is searched with searchsorted, so it is assumed to
    be sorted ascending.

    Node placement, in order of precedence:
      * nodeindices: explicit indices into xvals (nodes at xvals[nodeindices]);
      * nodelist: nodes at the given x values (list or array);
      * nodecount: that many equally-spaced nodes from xvals[0] to xvals[-1].
    At least two nodes are required.

    Each node's value comes from a quadratic fit (fitting_toolkit.polynomial_fit(2))
    to the data around that node; the first and last fits also provide the
    end slopes passed to spline() as yp1 / ypn.

    Returns (nodelist, ya, y2a), suitable for splint(nodelist, ya, y2a, x).
    """
    assert not (_is_empty(nodelist) and _is_empty(nodeindices) and not nodecount), \
        "Must have either a list of nodes or a node count"

    fitter = fitting_toolkit.polynomial_fit(2)  # will fit quadratic sections

    if _is_empty(nodeindices):
        if _is_empty(nodelist):
            # make equally-spaced nodelist; float() avoids integer division
            # when xvals holds integers
            spacing = float(xvals[-1] - xvals[0]) / (nodecount - 1)
            nodelist = Numeric.array(range(nodecount), Numeric.Float) * spacing + xvals[0]
            nodelist[-1] = xvals[-1]  # make sure no roundoff error clips the last point!
        else:
            # as an array, so the midpoint arithmetic below is elementwise
            # (a plain list would be concatenated instead)
            nodelist = Numeric.asarray(nodelist, Numeric.Float)

        nodeindices = Numeric.searchsorted(xvals, nodelist)
        # fit chunks are bounded halfway between neighbouring nodes
        boundindices = Numeric.searchsorted(xvals, (nodelist[1:] + nodelist[:-1]) * 0.5)
    else:
        nodeindices = Numeric.asarray(nodeindices)  # elementwise sums, not list concatenation
        boundindices = (nodeindices[1:] + nodeindices[:-1]) // 2
        nodelist = Numeric.take(xvals, nodeindices)

    nodecount = len(nodeindices)

    ya = Numeric.zeros(nodecount, Numeric.Float)

    # fit first chunk un-centered to get slope at start
    # (funcparams[0] / [1] are presumably value / slope at xcenter -- per the
    # original comments; confirm against fitting_toolkit)
    fitter.fit_data(xvals[:nodeindices[1]], yvals[:nodeindices[1]], xcenter=nodelist[0])
    ya[0] = fitter.funcparams[0]
    yp1 = fitter.funcparams[1]

    for i in range(1, nodecount - 1):
        chunkstart = boundindices[i - 1]
        chunkend = boundindices[i]
        fitter.fit_data(xvals[chunkstart:chunkend], yvals[chunkstart:chunkend], xcenter=nodelist[i])
        ya[i] = fitter.funcparams[0]

    # fit last chunk un-centered to get slope at end
    fitter.fit_data(xvals[nodeindices[-2]:], yvals[nodeindices[-2]:], xcenter=nodelist[-1])
    ya[-1] = fitter.funcparams[0]
    ypn = fitter.funcparams[1]

    y2a = spline(nodelist, ya, yp1=yp1, ypn=ypn)
    return nodelist, ya, y2a
00197     
00198 
if __name__ == "__main__":
    # Demo / self-test: interpolate a small table (vectorised and scalar),
    # print derivatives, show the out-of-range errors, and -- if the optional
    # graphite package is installed -- plot an extrapolated spline.
    import traceback

    testlist = ((0, 1), (1, 1), (2, 3), (3, 4), (4, 2), (5, 6), (7, 9), (10, 6), (15, 2), (16, -1))
    # testlist = ((0, 0), (1, 1), (2, 4), (3, 9), (4, 16), (5, 25), (7, 49), (10, 100), (15, 225), (16, 256))
    xlist = [pt[0] for pt in testlist]
    ylist = [pt[1] for pt in testlist]
    print("\n\nStarting splint tests...\n%s" % (testlist,))
    y2 = spline(xlist, ylist, yp1=-5, ypn=10)
    r = (0, 1, 2, 3.5, 3.7, 4, 6, 7, 2, 8, 9, 10, 11, 5, 12, 13, 14, 15, 16)
    v = splint(xlist, ylist, y2, r)
    print(y2)
    # the vectorised and scalar evaluations should agree
    for ri, vi in zip(r, v):
        print("%.1f %.3f %.3f" % (ri, vi, splint(xlist, ylist, y2, ri)))

    v, vp, vpp = splint(xlist, ylist, y2, r, derivs=True)
    for ri, vi, vpi, vppi in zip(r, v, vp, vpp):
        print("%5.1f %10.3f %10.3f %10.3f" % (ri, vi, vpi, vppi))

    print("The next operations should print exceptions")
    out_of_range = (100.0, (1, 2, 2.5, 3, -5, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
    for bad in out_of_range:
        try:
            splint(xlist, ylist, y2, bad)
        except RangeError:  # only the expected error; anything else is a real bug
            traceback.print_exc()

    try:
        import graphite
    except ImportError:
        print("Graphite not available... plotted results not shown")
    else:
        try:
            xx, yy, yy2 = spline_extrapolate(xlist, ylist, yp1=None, ypn=-2.5,
                                             xmin=xlist[0] - 2, xmax=xlist[-1] + 2)
            g = graphite.Graph()
            ds1 = graphite.Dataset()
            ds1.x = xx
            ds1.y = yy
            g.datasets.append(ds1)
            f1 = graphite.PointPlot()
            f1.lineStyle = None
            f1.symbol = graphite.CircleSymbol
            f1.symbolStyle = graphite.SymbolStyle(size=5, fillColor=graphite.red, edgeColor=graphite.red)
            g.formats = []
            g.formats.append(f1)
            finex = Numeric.array(range(-20, 181), Float) * 0.1
            finey = splint(xx, yy, yy2, finex)
            ds2 = graphite.Dataset()
            ds2.x = finex
            ds2.y = finey
            g.datasets.append(ds2)
            f2 = graphite.PointPlot()
            f2.lineStyle = graphite.LineStyle(width=1, color=graphite.green, kind=graphite.SOLID)
            f2.symbol = None
            g.formats.append(f2)
            g.bottom = 400
            g.right = 700
            try:
                graphite.genOutput(g, 'QD', size=(800, 500))
            except Exception:
                # fall back to PDF output if 'QD' output fails
                graphite.genOutput(g, 'PDF', size=(800, 500))
        except Exception:
            traceback.print_exc()
            print("Graphite plotting failed... plotted results not shown")
00263         

Generated on Wed Nov 21 10:18:31 2007 for analysis by  doxygen 1.5.4