diff --git a/canopen/network.py b/canopen/network.py index 6a8d95f6..4474c192 100644 --- a/canopen/network.py +++ b/canopen/network.py @@ -107,7 +107,8 @@ def connect(self, *args, **kwargs) -> Network: if self.bus is None: self.bus = can.Bus(*args, **kwargs) logger.info("Connected to '%s'", self.bus.channel_info) - self.notifier = can.Notifier(self.bus, self.listeners, self.NOTIFIER_CYCLE) + if self.notifier is None: + self.notifier = can.Notifier(self.bus, self.listeners, self.NOTIFIER_CYCLE) return self def disconnect(self) -> None: @@ -123,7 +124,11 @@ def disconnect(self) -> None: if self.bus is not None: self.bus.shutdown() self.bus = None - self.check() + try: + self.check() + finally: + # Release notifier after check + self.notifier = None def __enter__(self): return self diff --git a/test/test_network.py b/test/test_network.py index cd65ea71..d488dc41 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -285,6 +285,34 @@ def wait_for_periodicity(): if msg is not None: self.assertIsNone(bus.recv(PERIOD)) + def test_network_connect_does_not_recreate_notifier(self): + self.network.connect(interface="virtual") + self.addCleanup(self.network.disconnect) + notifier1 = self.network.notifier + self.assertIsNotNone(notifier1) + # Calling connect() again should reuse the existing notifier + self.network.connect(interface="virtual") + self.assertIs(self.network.notifier, notifier1) + + def test_network_disconnect_releases_notifier(self): + self.network.connect(interface="virtual") + self.assertIsNotNone(self.network.notifier) + self.network.disconnect() + self.assertIsNone(self.network.notifier) + + def test_network_disconnect_releases_notifier_on_exception(self): + self.network.connect(interface="virtual") + + class Custom(Exception): + pass + + self.network.notifier.exception = Custom("fake") + with self.assertRaises(Custom): + with self.assertLogs(level=logging.ERROR): + self.network.disconnect() + # Notifier must be released even when check() raises + self.assertIsNone(self.network.notifier) + class TestScanner(unittest.TestCase): TIMEOUT = 0.1