"""Cubic spline handling, in a manner compatible with the API in Numerical Recipes.

spline() builds the second-derivative table of a cubic spline and splint()
evaluates it; spline_extension()/spline_extrapolate() add quadratic end
extensions, cubeinterpolate() does exact four-point cubic interpolation and
approximate_least_squares_spline() builds a smoothing spline from local fits."""
_rcsid = "$Id: spline.py,v 1.23 2005/07/13 14:24:58 mendenhall Exp $"
00003
# Names exported by "from spline import *".
__all__ = [
    "spline", "splint", "cubeinterpolate", "RangeError",
    "spline_extension", "spline_extrapolate",
    "approximate_least_squares_spline",
]
00006
class RangeError(IndexError):
    """Raised by splint() when a requested x lies outside the tabulated
    range [xa[0], xa[-1]] of the spline."""
00009
00010 from Numeric import zeros, Float, searchsorted, array, asarray, take, clip
00011 import Numeric
00012
def spline(x, y, yp1=None, ypn=None):
    """y2 = spline(x_vals, y_vals, yp1=None, ypn=None)

    Build the table of second derivatives of the cubic spline through the
    points (x_vals[i], y_vals[i]); the result is the y2 argument expected by
    splint().  x_vals is expected to be strictly increasing.

    yp1 and ypn are the first derivatives imposed at the first and last
    point.  Passing None for either selects the natural end condition there
    (second derivative zero).

    The tridiagonal system is solved with one forward sweep (decomposition)
    followed by back-substitution, as in the Numerical Recipes routine."""
    npts = len(x)
    work = zeros(npts, Float)   # right-hand sides / decomposition workspace
    y2 = zeros(npts, Float)

    xs = asarray(x, Float)
    ys = asarray(y, Float)

    # vectorised per-interval quantities
    h = xs[1:] - xs[:-1]                 # interval widths
    inv_h = 1.0/h
    inv_span = 1.0/(xs[2:] - xs[:-2])    # 1/(x[i+1]-x[i-1]) at interior points
    delta_y = ys[1:] - ys[:-1]
    sigmas = h[:-1]*inv_span             # (x[i]-x[i-1])/(x[i+1]-x[i-1])
    slopes = delta_y*inv_h

    # change of slope across each interior point
    work[1:-1] = slopes[1:] - slopes[:-1]

    # lower boundary condition
    if yp1 is not None:
        y2[0] = -0.5
        work[0] = (3.0*inv_h[0])*(delta_y[0]*inv_h[0] - yp1)
    else:
        y2[0] = 0.0
        work[0] = 0.0

    # forward sweep of the tridiagonal decomposition
    for i, sig in zip(range(1, npts - 1), sigmas):
        p = sig*y2[i-1] + 2.0
        y2[i] = (sig - 1.0)/p
        work[i] = (6.0*work[i]*inv_span[i-1] - sig*work[i-1])/p

    # upper boundary condition
    if ypn is not None:
        qn = 0.5
        un = (3.0*inv_h[-1])*(ypn - delta_y[-1]*inv_h[-1])
    else:
        qn = un = 0.0

    # back-substitution
    y2[-1] = (un - qn*work[-2])/(qn*y2[-2] + 1.0)
    for k in range(npts - 2, -1, -1):
        y2[k] = y2[k]*y2[k+1] + work[k]

    return y2
00057
def spline_extension(x, y, y2, xmin=None, xmax=None):
    """x, y, y2 = spline_extension(x_vals, y_vals, y2vals, xmin=None, xmax=None)
    returns the x, y, y2 table for the spline as needed by splint() with adjustments to allow quadratic extrapolation
    outside the range x[0]-x[-1], from xmin (or x[0] if xmin is None) to xmax (or x[-1] if xmax is None),
    working from x, y, y2 from an already-created spline.

    A knot is added at xmin (and/or xmax) carrying the same second derivative
    as the adjacent end knot, so splint() sees a constant y'' on the new
    segment, i.e. a quadratic.  The new ordinate is the second-order Taylor
    expansion about the end knot using the spline's own end slope, so y, y'
    and y'' are continuous at the join.  xmin is expected to be below x[0]
    and xmax above x[-1].  The returned tables are Float arrays."""
    # Float conversion avoids Python 2 integer division for int list input.
    x = asarray(x, Float)
    y = asarray(y, Float)
    y2 = asarray(y2, Float)

    xl = [x]
    yl = [y]
    y2l = [y2]

    if xmin is not None:
        h0 = x[1] - x[0]
        h1 = xmin - x[0]
        # Slope of the first spline segment at x[0] (A=1, B=0 in the
        # Numerical Recipes derivative formula): the weights are 2*y2[0]+y2[1].
        slope0 = (y[1] - y[0])/h0 - h0*(2.0*y2[0] + y2[1])/6.0
        yextrap = y[0] + slope0*h1 + y2[0]*h1*h1/2.0
        xl.insert(0, (xmin,))
        yl.insert(0, (yextrap,))
        y2l.insert(0, (y2[0],))

    if xmax is not None:
        h0 = x[-1] - x[-2]
        h1 = xmax - x[-1]
        # Slope of the last spline segment at x[-1] (A=0, B=1): the weights
        # are y2[-2]+2*y2[-1].
        slopen = (y[-1] - y[-2])/h0 + h0*(y2[-2] + 2.0*y2[-1])/6.0
        yextrap = y[-1] + slopen*h1 + y2[-1]*h1*h1/2.0
        xl.append((xmax,))
        yl.append((yextrap,))
        y2l.append((y2[-1],))

    return Numeric.concatenate(xl), Numeric.concatenate(yl), Numeric.concatenate(y2l)
00085
def spline_extrapolate(x, y, yp1=None, ypn=None, xmin=None, xmax=None):
    """x, y, y2 = spline_extrapolate(x_vals, y_vals, yp1=None, ypn=None, xmin=None, xmax=None)

    Convenience wrapper: builds the spline with spline(x_vals, y_vals, yp1, ypn)
    and passes it through spline_extension(), giving the x, y, y2 table that
    splint() needs to extrapolate quadratically outside x[0]-x[-1], from xmin
    (or x[0] if xmin is None) to xmax (or x[-1] if xmax is None)."""
    second_derivs = spline(x, y, yp1, ypn)
    return spline_extension(x, y, second_derivs, xmin, xmax)
00092
00093 import types
00094
def splint(xa, ya, y2a, x, derivs=False):
    """y = splint(xa, ya, y2a, x, derivs=False)
    returns the interpolated value(s) from the spline tabulated by xa, ya and
    y2a (the second-derivative table from spline(), spline_extension() or
    spline_extrapolate()).

    x can either be a scalar or a listable item, in which case a Numeric Float
    array will be returned and the multiple interpolations will be done
    somewhat more efficiently.  Anything without a len() (int, long, float,
    array scalars) is treated as a scalar.
    If derivs is true, return (y, y', y'') instead of just y.

    xa is assumed to be sorted in increasing order, since it is searched with
    searchsorted().  Raises RangeError if any x lies outside [xa[0], xa[-1]]."""
    # Duck-typed scalar test: the old exact-type check rejected long and
    # array scalars, sending them down the sequence path where they failed.
    try:
        len(x)
    except TypeError:
        scalar = True
    else:
        scalar = False

    if scalar:
        if x < xa[0] or x > xa[-1]:
            raise RangeError("%f not in range (%f, %f) in splint()" % (x, xa[0], xa[-1]))

        # upper end of the bracketing interval; x == xa[0] uses the first interval
        khi = max(searchsorted(xa, x), 1)
        klo = khi - 1
        h = float(xa[khi] - xa[klo])
        a = (xa[khi] - x)/h
        b = 1.0 - a
        ylo = ya[klo]; yhi = ya[khi]; y2lo = y2a[klo]; y2hi = y2a[khi]
    else:
        x = asarray(x, Float)
        # len() guard: min()/max() of an empty sequence would raise ValueError
        if len(x) and (min(x) < xa[0] or max(x) > xa[-1]):
            raise RangeError("(%f, %f) not in range (%f, %f) in splint()" % (min(x), max(x), xa[0], xa[-1]))

        # Upper end of each bracketing interval; clipping keeps x == xa[0] in
        # the first interval and every index a valid interval end.
        khi = clip(searchsorted(xa, x), 1, len(xa) - 1)
        klo = khi - 1
        xhi = take(xa, khi)
        xlo = take(xa, klo)
        yhi = take(ya, khi)
        ylo = take(ya, klo)
        y2hi = take(y2a, khi)
        y2lo = take(y2a, klo)

        h = (xhi - xlo).astype(Float)
        a = (xhi - x)/h
        b = 1.0 - a

    y = a*ylo + b*yhi + ((a*a*a - a)*y2lo + (b*b*b - b)*y2hi)*(h*h)/6.0
    if derivs:
        yp = (yhi - ylo)/h + ((3*b*b - 1)*y2hi - (3*a*a - 1)*y2lo)*h/6.0
        ypp = b*y2hi + a*y2lo
        return y, yp, ypp
    return y
00134
00135
def cubeinterpolate(xlist, ylist, x3):
    """Return the value at x3 of the cubic polynomial passing exactly through
    the four points given by xlist and ylist (exact cubic interpolation, not
    splining).  Both sequences must hold exactly four values.

    The points are shifted so that the first one sits at the origin; the
    closed-form interpolant then has no constant term in the numerator."""
    x0, xb, xc, xd = xlist
    # abscissae (and the target) measured from the first point
    hb, t, hc, hd = float(xb - x0), float(x3 - x0), float(xc - x0), float(xd - x0)
    y0, yb, yc, yd = ylist
    # ordinates measured from the first point
    db, dc, dd = float(yb - y0), float(yc - y0), float(yd - y0)

    # numerator = t*(c_lin + t**2*c_cub + t*c_quad), i.e. the coefficients of
    # t, t**3 and t**2; the sums are kept in their original evaluation order
    c_lin = (hb**2*hd**2*(-hb + hd)*dc + hc**3*(hd**2*db - hb**2*dd)
             + hc**2*(-(hd**3*db) + hb**3*dd))
    c_cub = hb*hd*(-hb + hd)*dc + hc**2*(hd*db - hb*dd) + hc*(-(hd**2*db) + hb**2*dd)
    c_quad = hb*hd*(hb**2 - hd**2)*dc + hc**3*(-(hd*db) + hb*dd) + hc*(hd**3*db - hb**3*dd)
    denom = hb*(hb - hc)*hc*(hb - hd)*(hc - hd)*hd

    return t*(c_lin + t**2*c_cub + t*c_quad)/denom + y0
00150
00151 from analysis import fitting_toolkit
00152
def approximate_least_squares_spline(xvals, yvals, nodelist=None, nodeindices=None, nodecount=None):
    """Compute an approximation to the true least-squares spline to the dataset (xvals, yvals).

    Node placement, in order of precedence:
      nodeindices -- explicit indices into xvals for the nodes; overrides
                     everything else.
      nodelist    -- nodes are placed at these x values.
      nodecount   -- used when nodelist is not given: this many equally-spaced
                     nodes from xvals[0] to xvals[-1].
    xvals is assumed sorted in increasing order, since it is searched with
    searchsorted().

    Around each node a quadratic (fitting_toolkit.polynomial_fit(2)) is fitted
    with xcenter at the node.  The first node uses the data up to the second
    node, the last node uses the data from the next-to-last node on, and each
    interior node uses the data between the neighbouring midpoints.
    funcparams[0] of each fit becomes the node ordinate, and funcparams[1] of
    the first and last fits is passed to spline() as yp1/ypn (presumably the
    fitted value and slope at xcenter -- confirm against fitting_toolkit).

    Returns (nodelist, ya, y2a), a table suitable for splint().

    Raises ValueError if no node specification is given or fewer than two
    nodes result."""
    # Explicit emptiness tests rather than truthiness, so that array
    # arguments are not judged by their element values.
    have_indices = nodeindices is not None and len(nodeindices) > 0
    have_nodes = nodelist is not None and len(nodelist) > 0
    if not (have_indices or have_nodes or nodecount):
        raise ValueError("Must have either a list of nodes or a node count")

    fitter = fitting_toolkit.polynomial_fit(2)

    if not have_indices:
        if not have_nodes:
            if nodecount < 2:
                raise ValueError("Need at least two nodes")
            # float() avoids Python 2 integer division when xvals holds ints
            spacing = float(xvals[-1] - xvals[0])/(nodecount - 1)
            nodelist = Numeric.array(range(nodecount), Numeric.Float)*spacing + xvals[0]
            nodelist[-1] = xvals[-1]    # land exactly on the last data point
        else:
            # An array, so the midpoint arithmetic below is element-wise;
            # on a plain list '+' would concatenate and '*0.5' would fail.
            nodelist = Numeric.asarray(nodelist, Numeric.Float)

        nodeindices = Numeric.searchsorted(xvals, nodelist)
        # data are shared out between neighbouring nodes at the midpoints
        boundindices = Numeric.searchsorted(xvals, (nodelist[1:] + nodelist[:-1])*0.5)
    else:
        # array for the same reason as nodelist above
        nodeindices = Numeric.asarray(nodeindices)
        boundindices = (nodeindices[1:] + nodeindices[:-1])//2
        nodelist = Numeric.take(xvals, nodeindices)

    nodecount = len(nodeindices)
    if nodecount < 2:
        raise ValueError("Need at least two nodes")

    ya = Numeric.zeros(nodecount, Numeric.Float)

    # first node: its fit also supplies the left end slope
    fitter.fit_data(xvals[:nodeindices[1]], yvals[:nodeindices[1]], xcenter=nodelist[0])
    ya[0] = fitter.funcparams[0]
    yp1 = fitter.funcparams[1]

    for i in range(1, nodecount - 1):
        chunkstart = boundindices[i-1]
        chunkend = boundindices[i]
        fitter.fit_data(xvals[chunkstart:chunkend], yvals[chunkstart:chunkend], xcenter=nodelist[i])
        ya[i] = fitter.funcparams[0]

    # last node: its fit also supplies the right end slope
    fitter.fit_data(xvals[nodeindices[-2]:], yvals[nodeindices[-2]:], xcenter=nodelist[-1])
    ya[-1] = fitter.funcparams[0]
    ypn = fitter.funcparams[1]

    y2a = spline(nodelist, ya, yp1=yp1, ypn=ypn)
    return nodelist, ya, y2a
00197
00198
00199 if __name__=="__main__":
00200 import traceback
00201 testlist=((0,1), (1,1),(2,3),(3,4),(4,2),(5,6),(7,9),(10,6),(15,2), (16,-1))
00202
00203 xlist=[i[0] for i in testlist]
00204 ylist=[i[1] for i in testlist]
00205 print "\n\nStarting splint tests...\n", testlist
00206 y2=spline(xlist,ylist, yp1=-5, ypn=10)
00207 r=(0,1,2,3.5, 3.7, 4,6,7,2,8,9,10,11, 5, 12,13,14, 15, 16)
00208 v=splint(xlist, ylist, y2, r)
00209 print y2
00210 for i in range(len(r)):
00211 print "%.1f %.3f %.3f" % (r[i], v[i], splint(xlist, ylist, y2, r[i]))
00212
00213 v, vp, vpp=splint(xlist, ylist, y2, r, derivs=True)
00214 for i in range(len(r)):
00215 print "%5.1f %10.3f %10.3f %10.3f" % (r[i], v[i], vp[i], vpp[i])
00216
00217 print "The next operations should print exceptions"
00218 try:
00219 splint(xlist, ylist, y2, 100.0)
00220 except:
00221 traceback.print_exc()
00222 try:
00223 splint(xlist, ylist, y2, (1,2,2.5, 3,-5, 4,5,6,7,8,9,10,11,12,13,14,15,16,17,18))
00224 except:
00225 traceback.print_exc()
00226
00227 try:
00228
00229 xx, yy, yy2=spline_extrapolate(xlist,ylist, yp1=None, ypn=-2.5, xmin=xlist[0]-2, xmax=xlist[-1]+2)
00230 import graphite
00231 import Numeric
00232 g=graphite.Graph()
00233 ds1=graphite.Dataset()
00234 ds1.x=xx
00235 ds1.y=yy
00236 g.datasets.append(ds1)
00237 f1 = graphite.PointPlot()
00238 f1.lineStyle = None
00239 f1.symbol = graphite.CircleSymbol
00240 f1.symbolStyle=graphite.SymbolStyle(size=5, fillColor=graphite.red, edgeColor=graphite.red)
00241 g.formats=[]
00242 g.formats.append(f1)
00243 finex=Numeric.array(range(-20,181),Float)*0.1
00244 finey=splint(xx, yy, yy2, finex)
00245 ds2=graphite.Dataset()
00246 ds2.x=finex
00247 ds2.y=finey
00248 g.datasets.append(ds2)
00249 f2 = graphite.PointPlot()
00250 f2.lineStyle = graphite.LineStyle(width=1, color=graphite.green, kind=graphite.SOLID)
00251 f2.symbol = None
00252 g.formats.append(f2)
00253 g.bottom=400
00254 g.right=700
00255 try:
00256 graphite.genOutput(g,'QD', size=(800,500))
00257 except:
00258 graphite.genOutput(g,'PDF', size=(800,500))
00259 except:
00260 import traceback
00261 traceback.print_exc()
00262 print "Graphite not available... plotted results not shown"
00263