884
883
assert dz.shape == dshape
885
884
assert dz.dtype.type == dtype
886
class _TestCorrelate(TestCase):
    """Shared fixtures and checks for the ``np.correlate`` test cases.

    The test methods read ``self.old_behavior``, which this class does not
    assign; concrete subclasses are expected to supply it.
    """

    def _setup(self, dt):
        """Create the operands and the expected 'full' results with dtype *dt*."""
        self.x = np.array([1, 2, 3, 4, 5], dtype=dt)
        self.y = np.array([-1, -2, -3], dtype=dt)
        # Compared against correlate(x, y, 'full').
        self.z1 = np.array([-3., -8., -14., -20., -26., -14., -5.], dtype=dt)
        # Compared against correlate(y, x, 'full').
        self.z2 = np.array([-5., -14., -26., -20., -14., -8., -3.], dtype=dt)

    def test_float(self):
        """Correlate float operands in both argument orders."""
        # ``np.float`` was only an alias of the builtin ``float`` and is gone
        # from newer NumPy releases; the builtin works on every version.
        self._setup(float)
        z = np.correlate(self.x, self.y, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, self.z1)
        z = np.correlate(self.y, self.x, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, self.z2)

    def test_object(self):
        """Correlate object-dtype operands in both argument orders."""
        # Every test method runs on a fresh instance, so the fixtures have to
        # be built here as well; without this call ``self.x`` does not exist.
        self._setup(object)
        z = np.correlate(self.x, self.y, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, self.z1)
        z = np.correlate(self.y, self.x, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, self.z2)
907
class TestCorrelate(_TestCorrelate):
    """Run the shared correlate checks against the legacy definition."""

    # The inherited tests pass ``self.old_behavior`` to ``np.correlate``.
    # NOTE(review): True is chosen because the comment in ``_setup`` and the
    # expected values in ``test_complex`` describe the legacy definition.
    # Confirm against the NumPy version this suite targets.
    old_behavior = True

    def _setup(self, dt):
        """Build the shared fixtures, then make both expected outputs equal."""
        # correlate uses an unconventional definition so that correlate(a, b)
        # == correlate(b, a), so force the corresponding outputs to be the same
        _TestCorrelate._setup(self, dt)
        self.z2 = self.z1

    def test_complex(self):
        """Correlate complex operands; the expected values are unconjugated."""
        # Builtin ``complex`` replaces the removed ``np.complex`` alias.
        x = np.array([1, 2, 3, 4+1j], dtype=complex)
        y = np.array([-1, -2j, 3+1j], dtype=complex)
        r_z = np.array([3+1j, 6, 8-1j, 9+1j, -1-8j, -4-1j], dtype=complex)
        # Request the definition explicitly, as every other call in these
        # tests does, instead of depending on the library default.
        z = np.correlate(x, y, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, r_z)

    def test_float(self):
        """Delegate to the shared float check."""
        _TestCorrelate.test_float(self)

    def test_object(self):
        """Delegate to the shared object-dtype check."""
        _TestCorrelate.test_object(self)
932
class TestCorrelateNew(_TestCorrelate):
    """Run the shared correlate checks against the conjugating definition."""

    # The inherited tests pass ``self.old_behavior`` to ``np.correlate``.
    # NOTE(review): False is chosen because ``test_complex`` expects the
    # conjugated, order-sensitive results. Confirm against the NumPy version
    # this suite targets.
    old_behavior = False

    def test_complex(self):
        """Check correlate(y, x) against the reversed conjugate of correlate(x, y)."""
        # Builtin ``complex`` replaces the removed ``np.complex`` alias.
        x = np.array([1, 2, 3, 4+1j], dtype=complex)
        y = np.array([-1, -2j, 3+1j], dtype=complex)
        # Values for correlate(x, y, 'full') when the second operand is
        # conjugated, e.g. the first entry is x[0] * conj(y[2]) = 3-1j.
        r_z = np.array([3-1j, 6, 8+1j, 11+5j, -5+8j, -4-1j], dtype=complex)
        # The test calls correlate(y, x), so the reference is reversed and
        # conjugated to match the swapped operand order.
        r_z = r_z[::-1].conjugate()
        z = np.correlate(y, x, 'full', old_behavior=self.old_behavior)
        assert_array_almost_equal(z, r_z)
947
x = np.arange(6).reshape((2, 3))
948
assert_array_equal(np.argwhere(x > 1),
955
assert_equal(np.argwhere([4, 0, 2, 1, 3]), [[0], [2], [3], [4]])
888
957
if __name__ == "__main__":
    # Executed as a script: hand control to the module-level test runner.
    run_module_suite()