source: fact/tools/pyscripts/pyfact/fir_filter.py

Last change on this file was 14486, checked in by neise, 12 years ago
new class FastSlidingAverage with Weave! 50Hz .. yeah
  • Property svn:executable set to *
File size: 8.3 KB
Line 
1#!/usr/bin/python -tt
2#
3# Dominik Neise, Werner Lustermann
4# TU Dortmund, ETH Zurich
5#
6import numpy as np
7from scipy import signal
8
9# For FastSlidingAverage
10from scipy import weave
11from scipy.weave import converters
12
class FirFilter(object):
    """Generic digital filter built on ``scipy.signal.lfilter()``.

    Despite the class name, the denominator *a* may contain more than one
    nonzero coefficient, in which case the filter is recursive (IIR).
    Subclasses only choose the coefficients; this class applies them.
    """

    def __init__(self, b, a, name='general FIR filter'):
        """Store the filter coefficients.

        See `scipy.signal.lfilter() <http://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html#scipy.signal.lfilter>`_

        *b* The numerator coefficient vector (1D).

        *a* The denominator coefficient vector (1D).
        If a[0] is not 1, lfilter normalizes both a and b by a[0].

        *name* Human-readable label, shown by ``str()``.
        """
        self.a = a
        self.b = b
        self.name = name

    def __call__(self, data):
        """Apply the filter to *data* along its last axis.

        *data* numpy array (or array-like) with at least one dimension,
        typically 1D (one trace) or 2D (one trace per row).  Every trace
        along the last axis is filtered independently.

        Returns a float array with the same shape as *data*.

        Raises ValueError if *data* is a scalar (0-dimensional).

        Start-up handling: the filter delay is max(len(a), len(b)) - 1.
        Each trace is extended at its start by *delay* copies of its first
        sample, filtered, and the extension is cut off again.  For pure FIR
        coefficients this reproduces exactly the output of a signal that had
        been resting at its first value; for recursive coefficients it is an
        approximation of that.
        """
        data = np.asarray(data)
        if data.ndim == 0:
            raise ValueError('FirFilter needs at least 1D data, got a scalar')

        # an empty trace has no first sample to pad with
        if data.shape[-1] == 0:
            return np.zeros(data.shape)

        delay = max(len(self.a), len(self.b)) - 1

        # (..., delay) block repeating each trace's first sample; np.ones
        # makes it float so integer input is not truncated
        initial = np.ones(data.shape[:-1] + (delay,)) * data[..., :1]
        extended = np.concatenate((initial, data), axis=-1)

        filtered = signal.lfilter(self.b, self.a, extended, axis=-1)
        # drop the artificial start-up samples again
        return filtered[..., delay:]

    def __str__(self):
        """Return the filter name and coefficients as a multi-line summary."""
        lines = [
            self.name,
            'initial condition for filter: signal@rest = 1st sample',
            'filter, coefficients:',
            'nominator ' + str(self.b),
            'denominator ' + str(self.a),
        ]
        return '\n'.join(lines)
76
class SlidingAverage(FirFilter):
    """Smooth data in the time domain with a sliding (moving) average.

    Each output sample is the mean of the current input sample and the
    ``length - 1`` preceding ones.
    """

    def __init__(self, length=8):
        """Build the averaging filter.

        *length* number of samples in the averaging window; must be >= 1.

        Raises ValueError for a window shorter than one sample, which would
        otherwise produce denominator coefficients lfilter cannot use.
        """
        length = int(length)
        if length < 1:
            raise ValueError('sliding average length must be >= 1, got %d' % length)
        b = np.ones(length)
        # a = [length, 0, ..., 0]: lfilter divides by a[0], turning the sum
        # over the window into its mean
        a = np.zeros(length)
        a[0] = length
        FirFilter.__init__(self, b, a, 'sliding average')
92
93
94
95
class FastSlidingAverage(object):
    """Sliding average over 2D data, computed in inlined C++ via scipy.weave.

    Intended for repeated calls on data of one fixed shape
    (pixels x slices): the output buffer is allocated once in ``__init__``
    and reused, so every call returns the *same* array object and
    overwrites the previous result.

    NOTE: scipy.weave is Python-2 only and was removed from SciPy in 0.19.
    """

    def __init__(self, shape, length=8):
        """Allocate the output buffer.

        *shape* shape of the data passed to ``__call__``; must be 2D
        (number of pixels, number of slices).

        *length* number of slices in the averaging window; must be between
        1 and the number of slices.

        Raises ValueError for a non-2D shape or an invalid window length.
        """
        length = int(length)
        if length < 1:
            raise ValueError('sliding average length must be >= 1, got %d' % length)
        self.length = length
        # allocate memory for the filtered data once
        self.filtered = np.zeros(shape, np.float64)
        if self.filtered.ndim != 2:
            raise ValueError('FastSlidingAverage needs a 2D shape, got %r' % (shape,))
        if length > self.filtered.shape[1]:
            # the C++ code would read past the end of each row
            raise ValueError('sliding average length %d exceeds the %d slices'
                             % (length, self.filtered.shape[1]))
        # normalized to a tuple, so e.g. a list argument still compares
        # equal to data.shape in __call__
        self.shape = self.filtered.shape

    def __call__(self, data):
        """Return the sliding average of *data* along its second axis.

        *data* 2D array with exactly the shape given to ``__init__``.

        For ``sl >= length - 1`` the output slice ``sl`` holds the mean of
        ``data[pix, sl-length+1 : sl+1]`` (trailing window).  The first
        ``length - 1`` slices are never written and therefore stay zero.

        Returns the internal buffer ``self.filtered`` (shared between calls).

        Raises ValueError if *data* has the wrong shape.
        """
        if self.shape != data.shape:
            raise ValueError('data has wrong shape')

        length = self.length
        numpix = data.shape[0]
        numslices = data.shape[1]
        filtered = self.filtered

        # Running-sum window: prime the sum with the first length-1 slices,
        # then for each output slice add the newest sample, store the mean
        # and drop the oldest one.  The window ending at sl never reads past
        # sl, so the loop runs up to the last slice.
        cppcode = """
        double sum = 0;
        for (int pix=0; pix<numpix; pix++)
        {
            sum = 0;
            for ( int i=0; i<length-1; ++i)
            {
                sum += data(pix,i);
            }
            for ( int sl=length-1; sl<numslices; ++sl)
            {
                sum += data(pix,sl);
                filtered(pix,sl) = sum/length;
                sum -= data(pix,sl-(length-1) );
            }
        }
        """

        weave.inline(cppcode,
                     ['length', 'numpix', 'numslices', 'data', 'filtered'],
                     type_converters=converters.blitz)

        return filtered
155
class CFD(FirFilter):
    """Constant Fraction Discriminator.

    For ``length >= 2`` the output is the input delayed by ``length - 1``
    samples minus *ratio* times the undelayed input:
    ``y[n] = x[n - (length-1)] - ratio * x[n]``.
    """

    def __init__(self, length=10., ratio=0.75):
        """Build the CFD coefficients.

        *length* filter length; the subtracted copy is delayed by
        ``length - 1`` samples.  It is converted to int, so the float
        default keeps working with NumPy versions that reject float sizes
        and indices.  Must be >= 1.

        *ratio* fraction of the undelayed input that is subtracted.

        Raises ValueError if *length* < 1.
        """
        length = int(length)
        if length < 1:
            raise ValueError('CFD length must be >= 1, got %d' % length)
        b = np.zeros(length)
        a = np.zeros(length)
        b[0] = -1. * ratio
        # for length == 1 this overwrites b[0], leaving a pass-through filter
        b[length - 1] = 1.
        a[0] = 1.
        FirFilter.__init__(self, b, a, 'constant fraction discriminator')
167
168
class RemoveSignal(FirFilter):
    """Estimator used to identify DRS4 spikes.

    Uses the fixed 3-tap kernel (-0.5, 1, -0.5): each output sample is the
    previous input sample minus the mean of its two neighbours, so constant
    and linear trends cancel while single-sample spikes stand out.
    """

    def __init__(self):
        """Set up the fixed kernel with a trivial (non-recursive) denominator."""
        numerator = np.array([-0.5, 1.0, -0.5])
        denominator = np.zeros_like(numerator)
        denominator[0] = 1.0
        super(RemoveSignal, self).__init__(numerator, denominator, 'remove signal')
181
182
def _test_SlidingAverage():
    """Interactively test SlidingAverage on a generated noisy step function.

    Plots the signal together with its 8-sample sliding average and waits
    for a key press before closing the plot window.
    """
    # imported here so the test also works when the module is imported,
    # not only when it runs as a script
    import matplotlib.pyplot as plt
    from plotters import Plotter
    from generator import SignalGenerator

    generate = SignalGenerator()
    plot = Plotter('_test_SlidingAverage')

    safilter = SlidingAverage(8)
    print(safilter)

    # named 'sig' so the module-level scipy 'signal' import is not shadowed
    sig = generate('len 100 noise 1.5 step 20 10 50')
    filtered = safilter(sig)
    plot([sig, filtered], ['original', 'filtered'])

    raw_input('press any key to go on')
    plt.close(plot.figure)
202
203
def _test_SlidingAverage2():
    """Interactively test SlidingAverage on constant 2D input.

    Filters a (6, 20) array filled with 3.0 row by row.  A sliding average
    of a constant must reproduce the constant, so the printed maximum
    deviation should be (close to) zero.
    """
    # imported here so the test also works when the module is imported,
    # not only when it runs as a script
    import matplotlib.pyplot as plt
    from plotters import Plotter

    plot = Plotter('_test_SlidingAverage2')

    safilter = SlidingAverage(8)
    print(safilter)

    sig = np.ones((6, 20)) * 3.0
    filtered = safilter(sig)
    print('max deviation from the constant input: %g' % np.abs(filtered - sig).max())

    raw_input('press any key to go on')
    plt.close(plot.figure)
222
223
224
def _test_CFD():
    """Interactively test the constant fraction discriminator.

    A generated noisy triangle pulse is lightly smoothed (3-sample sliding
    average) and then passed through CFD(8, 0.6); input and output are
    plotted, and the window closes after a key press.
    """
    # imported here so the test also works when the module is imported,
    # not only when it runs as a script
    import matplotlib.pyplot as plt
    from plotters import Plotter
    from generator import SignalGenerator

    generate = SignalGenerator()
    plot = Plotter('_test_CFD')

    sa = SlidingAverage(3)
    print('I apply a weak smoothing with a filter length of 3')
    cfd = CFD(8, 0.6)
    print(cfd)

    # named 'sig' so the module-level scipy 'signal' import is not shadowed
    sig = generate('len 100 noise 1.5 bsl -20 triangle 30 30 8 50')
    filtered = cfd(sa(sig))
    plot([sig, filtered], ['original', 'filtered'])

    raw_input('press any key to go on')
    plt.close(plot.figure)
245
def _test_RemoveSignal():
    """Interactively test the RemoveSignal spike estimator.

    A generated signal containing triangle pulses and spikes is filtered;
    input and output are plotted, and the window closes after a key press.
    """
    # imported here so the test also works when the module is imported,
    # not only when it runs as a script
    import matplotlib.pyplot as plt
    from plotters import Plotter
    from generator import SignalGenerator

    generate = SignalGenerator()
    plot = Plotter('_test_RemoveSignal')

    remove_signal = RemoveSignal()
    print(remove_signal)

    # named 'sig' so the module-level scipy 'signal' import is not shadowed
    sig = generate('len 100 noise 2 bsl -20 triangle 20 30 8 40 50 30 spike 50 50 15 50 80 50')
    filtered = remove_signal(sig)
    plot([sig, filtered], ['original', 'filtered'])

    raw_input('press any key to go on')
    plt.close(plot.figure)
264
def _test(filter_type, sig, noise_sigma=1.):
    """Interactively compare a noisy copy of *sig* with its filtered version.

    *filter_type* filter instance (callable, with a ``name`` attribute),
    e.g. ``SlidingAverage(8)``.

    *sig* 1D signal; it is copied, never modified.

    *noise_sigma* standard deviation of the Gaussian noise that is added.
    """
    # imported here so the test also works when the module is imported,
    # not only when it runs as a script
    import matplotlib.pyplot as plt
    from plotters import Plotter

    # build a new float array: the former 'sig += noise' modified the
    # caller's array and cannot hold float noise if sig is an int array
    noisy = np.asarray(sig, dtype=float) + np.random.randn(len(sig)) * noise_sigma

    plot = Plotter('_test with ' + str(filter_type.name))
    plot([noisy, filter_type(noisy)], ['original', 'filtered'])
    raw_input('press any key to go on')
    plt.close(plot.figure)
277
if __name__ == '__main__':
    # Run the interactive tests of this module when executed as a script.
    # 'plt' is bound at module level here; helpers that call plt.close()
    # without importing matplotlib themselves rely on this global.
    import matplotlib.pyplot as plt

    for run_test in (_test_SlidingAverage, _test_CFD, _test_RemoveSignal):
        run_test()

    tsig = np.ones(100)
    _test(filter_type=SlidingAverage(8), sig=tsig, noise_sigma=3.)
Note: See TracBrowser for help on using the repository browser.