diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py
index 134f28f609dfb99feccc90812e0ed5b4278c00d6..e481fb973f0920b74508fd34d27014c8f55126af 100644
--- a/src/lambo/segmenter/lambo.py
+++ b/src/lambo/segmenter/lambo.py
@@ -202,6 +202,7 @@ class Lambo():
         with torch.no_grad():
             X = [x.to(self.device) for x in X]
             Y = self.model(*X)
+        Y = Y.to('cpu')
         
         # perform postprocessing
         decisions = self.model.postprocessing(Y, text)