python/ext/py-source.c: add seek, thanks @davebrent for the heads up
authorPaul Brossier <piem@piem.org>
Sun, 21 Sep 2014 00:42:08 +0000 (21:42 -0300)
committerPaul Brossier <piem@piem.org>
Sun, 21 Sep 2014 00:42:08 +0000 (21:42 -0300)
python/ext/py-source.c
python/tests/test_source.py

index c6cad07..f9f972f 100644 (file)
@@ -69,6 +69,11 @@ static char Py_source_close_doc[] = ""
 "\n"
 "Close this source now.";
 
+static char Py_source_seek_doc[] = ""
+"x.seek(position)\n"
+"\n"
+"Seek to resampled frame position.";
+
 static PyObject *
 Py_source_new (PyTypeObject * pytype, PyObject * args, PyObject * kwds)
 {
@@ -238,6 +243,25 @@ Pyaubio_source_close (Py_source *self, PyObject *unused)
   Py_RETURN_NONE;
 }
 
+static PyObject *
+Pyaubio_source_seek (Py_source *self, PyObject *args)
+{
+  uint_t err = 0;
+
+  uint_t position;
+  if (!PyArg_ParseTuple (args, "I", &position)) {
+    return NULL;
+  }
+
+  err = aubio_source_seek(self->o, position);
+  if (err != 0) {
+    PyErr_SetString (PyExc_ValueError,
+        "error when seeking in source");
+    return NULL;
+  }
+  Py_RETURN_NONE;
+}
+
 static PyMethodDef Py_source_methods[] = {
   {"get_samplerate", (PyCFunction) Pyaubio_source_get_samplerate,
     METH_NOARGS, Py_source_get_samplerate_doc},
@@ -249,6 +273,8 @@ static PyMethodDef Py_source_methods[] = {
     METH_NOARGS, Py_source_do_multi_doc},
   {"close", (PyCFunction) Pyaubio_source_close,
     METH_NOARGS, Py_source_close_doc},
+  {"seek", (PyCFunction) Pyaubio_source_seek,
+    METH_VARARGS, Py_source_seek_doc},
   {NULL} /* sentinel */
 };
 
index ba05d06..f571a14 100755 (executable)
@@ -32,7 +32,7 @@ class aubio_source_test_case(aubio_source_test_case_base):
 
 class aubio_source_read_test_case(aubio_source_test_case_base):
 
-    def read_from_sink(self, f):
+    def read_from_source(self, f):
         total_frames = 0
         while True:
             vec, read = f()
@@ -42,25 +42,26 @@ class aubio_source_read_test_case(aubio_source_test_case_base):
         print "(", total_frames, "frames", "in",
         print total_frames / f.hop_size, "blocks", "at", "%dHz" % f.samplerate, ")",
         print "from", f.uri
+        return total_frames
 
     def test_samplerate_hopsize(self):
         for p in list_of_sounds:
             for samplerate, hop_size in zip([0, 44100, 8000, 32000], [ 512, 512, 64, 256]):
                 f = source(p, samplerate, hop_size)
                 assert f.samplerate != 0
-                self.read_from_sink(f)
+                self.read_from_source(f)
 
     def test_samplerate_none(self):
         for p in list_of_sounds:
             f = source(p)
             assert f.samplerate != 0
-            self.read_from_sink(f)
+            self.read_from_source(f)
 
     def test_samplerate_0(self):
         for p in list_of_sounds:
             f = source(p, 0)
             assert f.samplerate != 0
-            self.read_from_sink(f)
+            self.read_from_source(f)
 
     def test_wrong_samplerate(self):
         for p in list_of_sounds:
@@ -85,11 +86,23 @@ class aubio_source_read_test_case(aubio_source_test_case_base):
             f = source(p, 0, 0)
             assert f.samplerate != 0
             assert f.hop_size != 0
-            self.read_from_sink(f)
+            self.read_from_source(f)
+
+    def test_seek_to_half(self):
+        from random import randint
+        for p in list_of_sounds:
+            f = source(p, 0, 0)
+            assert f.samplerate != 0
+            assert f.hop_size != 0
+            a = self.read_from_source(f)
+            c = randint(0, a)
+            f.seek(c)
+            b = self.read_from_source(f)
+            assert a == b + c
 
 class aubio_source_readmulti_test_case(aubio_source_read_test_case):
 
-    def read_from_sink(self, f):
+    def read_from_source(self, f):
         total_frames = 0
         while True:
             vec, read = f.do_multi()
@@ -100,6 +113,7 @@ class aubio_source_readmulti_test_case(aubio_source_read_test_case):
         print f.channels, "channels and",
         print total_frames / f.hop_size, "blocks", "at", "%dHz" % f.samplerate, ")",
         print "from", f.uri
+        return total_frames
 
 if __name__ == '__main__':
     from unittest import main