From 34176b3751a20fcb3a08b805ee2bf27f835c9983 Mon Sep 17 00:00:00 2001 From: umerkhan95 <96595386+umerkhan95@users.noreply.github.com> Date: Thu, 12 Mar 2026 06:17:14 +0100 Subject: [PATCH] Add configurable batch_size to CNN to fix GPU out of memory on large directories The batch size was hardcoded at 64 with no way to lower it, which caused OOM on GPUs with limited memory. This adds a batch_size parameter to CNN() so users can reduce it when needed. Also fixes a few related problems found along the way: - features are now moved to cpu per batch instead of accumulating on gpu - collate_fn no longer crashes when all images in a batch are unreadable - bad_im_count now counts actual bad images instead of batches - pool in parallelise uses a context manager to avoid leaked workers - single image encoding uses explicit shape check instead of squeeze() --- imagededup/methods/cnn.py | 44 +++++++++--------- imagededup/utils/data_generator.py | 4 ++ imagededup/utils/general_utils.py | 10 ++--- tests/test_cnn.py | 71 ++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 26 deletions(-) diff --git a/imagededup/methods/cnn.py b/imagededup/methods/cnn.py index 1f88a85e..e7d7b1da 100644 --- a/imagededup/methods/cnn.py +++ b/imagededup/methods/cnn.py @@ -44,15 +44,17 @@ class CNN: def __init__( self, verbose: bool = True, - model_config: Optional[CustomModel] = None + model_config: Optional[CustomModel] = None, + batch_size: int = 64 ) -> None: """ Initialize a pytorch MobileNet model v3 that is sliced at the last convolutional layer. - Set the batch size for pytorch dataloader to be 64 samples. + Set the batch size for pytorch dataloader to be 64 samples by default. Args: verbose: Display progress bar if True else disable it. Default value is True. model_config: A CustomModel that can be used to initialize a custom PyTorch model along with the corresponding transform. + batch_size: Batch size for the dataloader during encoding generation. Lower values use less GPU memory. Default value is 64. """ self.model_config = model_config if model_config is not None else CustomModel( model=MobilenetV3(), transform=MobilenetV3.transform, name=MobilenetV3.name @@ -64,7 +66,9 @@ def __init__( ) # The logger needs to be bound to the class, otherwise stderr also gets # directed to stdout (Don't know why that is the case) - self.batch_size = 64 + if not isinstance(batch_size, int) or batch_size < 1: + raise ValueError('batch_size must be a positive integer') + self.batch_size = batch_size self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.logger.info(f"Device set to {self.device} ..") @@ -109,12 +113,7 @@ def _get_cnn_features_single(self, image_array: np.ndarray) -> np.ndarray: image_pp = image_pp.unsqueeze(0) img_features_tensor = self.model(image_pp.to(self.device)) - if self.device.type == "cuda": - unpacked_img_features_tensor = img_features_tensor.cpu().detach().numpy() - else: - unpacked_img_features_tensor = img_features_tensor.detach().numpy() - - return unpacked_img_features_tensor + return img_features_tensor.cpu().detach().numpy() def _get_cnn_features_batch( self, @@ -146,31 +145,36 @@ def _get_cnn_features_batch( with torch.no_grad(): for ims, filenames, bad_images in self.dataloader: + if ims is None or len(ims) == 0: + bad_im_count += len(bad_images) + continue arr = self.model(ims.to(self.device)) - feat_arr.extend(arr) + feat_arr.append(arr.cpu().detach().numpy()) + del arr + if self.device.type == 'cuda': + torch.cuda.empty_cache() all_filenames.extend(filenames) - if bad_images: - bad_im_count += 1 + bad_im_count += len(bad_images) if bad_im_count: self.logger.info( f"Found {bad_im_count} bad images, ignoring for encoding generation .." ) - feat_vec = torch.stack(feat_arr).squeeze() - feat_vec = ( - feat_vec.detach().numpy() - if self.device.type == "cpu" - else feat_vec.detach().cpu().numpy() - ) + if not feat_arr: + self.logger.info('No valid images found for encoding generation ..') + self.encoding_map = {} + return self.encoding_map + + feat_vec = np.vstack(feat_arr) valid_image_files = [filename for filename in all_filenames if filename] self.logger.info("End: Image encoding generation") filenames = generate_relative_names(image_dir, valid_image_files) if ( - len(feat_vec.shape) == 1 + feat_vec.shape[0] == 1 ): # can happen when encode_images is called on a directory containing a single image - self.encoding_map = {filenames[0]: feat_vec} + self.encoding_map = {filenames[0]: feat_vec[0]} else: self.encoding_map = {j: feat_vec[i, :] for i, j in enumerate(filenames)} return self.encoding_map diff --git a/imagededup/utils/data_generator.py b/imagededup/utils/data_generator.py index c2abb8a3..2f7bae45 100644 --- a/imagededup/utils/data_generator.py +++ b/imagededup/utils/data_generator.py @@ -46,6 +46,10 @@ def _collate_fn(batch: List[Dict]) -> Tuple[torch.tensor, str, str]: filenames.append(b['filename']) else: bad_images.append(b['filename']) + + if not ims: + return None, filenames, bad_images + return torch.stack(ims), filenames, bad_images diff --git a/imagededup/utils/general_utils.py b/imagededup/utils/general_utils.py index f9825072..2148e403 100644 --- a/imagededup/utils/general_utils.py +++ b/imagededup/utils/general_utils.py @@ -62,12 +62,10 @@ def save_json(results: Dict, filename: str, float_scores: bool = False) -> None: def parallelise(function: Callable, data: List, verbose: bool, num_workers: int) -> List: num_workers = 1 if num_workers < 1 else num_workers # Pool needs to have at least 1 worker. - pool = Pool(processes=num_workers) - results = list( - tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose) - ) - pool.close() - pool.join() + with Pool(processes=num_workers) as pool: + results = list( + tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose) + ) return results diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e36f2b62..30fd931c 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -88,6 +88,25 @@ def test__init_defaults(cnn): assert cnn.model_config.name == MobilenetV3.name +def test__init_custom_batch_size(): + cnn = CNN(batch_size=32) + assert cnn.batch_size == 32 + + cnn = CNN(batch_size=1) + assert cnn.batch_size == 1 + + +def test__init_invalid_batch_size(): + with pytest.raises(ValueError): + CNN(batch_size=0) + + with pytest.raises(ValueError): + CNN(batch_size=-1) + + with pytest.raises(ValueError): + CNN(batch_size='abc') + + def test__init_custom(): cnn = CNN(model_config=CustomModel(model=EfficientNet(), transform=EfficientNet.transform, @@ -983,6 +1002,58 @@ def test_find_duplicates_to_remove_encoding_integration(cnn): ) +# batch_size + + +def test_small_batch_size_produces_same_results(cnn): + cnn_small = CNN(batch_size=2) + encodings_default = cnn.encode_images(TEST_IMAGE_DIR) + encodings_small = cnn_small.encode_images(TEST_IMAGE_DIR) + + assert set(encodings_default.keys()) == set(encodings_small.keys()) + for k in encodings_default: + np.testing.assert_allclose(encodings_default[k], encodings_small[k], atol=1e-5) + + +def test_batch_size_one_produces_same_results(cnn): + cnn_one = CNN(batch_size=1) + encodings_default = cnn.encode_images(TEST_IMAGE_DIR) + encodings_one = cnn_one.encode_images(TEST_IMAGE_DIR) + + assert set(encodings_default.keys()) == set(encodings_one.keys()) + for k in encodings_default: + np.testing.assert_allclose(encodings_default[k], encodings_one[k], atol=1e-5) + + +def test_batch_size_larger_than_dataset(): + cnn_large = CNN(batch_size=128) + encodings = cnn_large.encode_images(TEST_IMAGE_DIR) + assert len(encodings) == 10 + + +def test_small_batch_size_find_duplicates_integration(): + cnn_small = CNN(batch_size=2) + duplicates = cnn_small.find_duplicates( + image_dir=TEST_IMAGE_DIR_MIXED, + min_similarity_threshold=0.9, + scores=False, + ) + assert 'ukbench00120.jpg' in duplicates + assert len(duplicates['ukbench00120.jpg']) > 0 + assert len(duplicates['ukbench09268.jpg']) == 0 + + +def test_all_bad_images_returns_empty_encoding(tmp_path): + bad_file = tmp_path / 'corrupt.jpg' + bad_file.write_bytes(b'not an image') + bad_file2 = tmp_path / 'corrupt2.jpg' + bad_file2.write_bytes(b'also not an image') + + cnn_inst = CNN() + result = cnn_inst.encode_images(tmp_path) + assert result == {} + + def test_scores_saving(cnn): save_file = 'myduplicates.json' cnn.find_duplicates(