diff --git a/imagededup/methods/cnn.py b/imagededup/methods/cnn.py index 1f88a85e..e7d7b1da 100644 --- a/imagededup/methods/cnn.py +++ b/imagededup/methods/cnn.py @@ -44,15 +44,17 @@ class CNN: def __init__( self, verbose: bool = True, - model_config: Optional[CustomModel] = None + model_config: Optional[CustomModel] = None, + batch_size: int = 64 ) -> None: """ Initialize a pytorch MobileNet model v3 that is sliced at the last convolutional layer. - Set the batch size for pytorch dataloader to be 64 samples. + Set the batch size for pytorch dataloader to be 64 samples by default. Args: verbose: Display progress bar if True else disable it. Default value is True. model_config: A CustomModel that can be used to initialize a custom PyTorch model along with the corresponding transform. + batch_size: Batch size for the dataloader during encoding generation. Lower values use less GPU memory. Default value is 64. """ self.model_config = model_config if model_config is not None else CustomModel( model=MobilenetV3(), transform=MobilenetV3.transform, name=MobilenetV3.name @@ -64,7 +66,9 @@ def __init__( ) # The logger needs to be bound to the class, otherwise stderr also gets # directed to stdout (Don't know why that is the case) - self.batch_size = 64 + if not isinstance(batch_size, int) or batch_size < 1: + raise ValueError('batch_size must be a positive integer') + self.batch_size = batch_size self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.logger.info(f"Device set to {self.device} ..") @@ -109,12 +113,7 @@ def _get_cnn_features_single(self, image_array: np.ndarray) -> np.ndarray: image_pp = image_pp.unsqueeze(0) img_features_tensor = self.model(image_pp.to(self.device)) - if self.device.type == "cuda": - unpacked_img_features_tensor = img_features_tensor.cpu().detach().numpy() - else: - unpacked_img_features_tensor = img_features_tensor.detach().numpy() - - return unpacked_img_features_tensor + return img_features_tensor.cpu().detach().numpy() def _get_cnn_features_batch( self, @@ -146,31 +145,36 @@ def _get_cnn_features_batch( with torch.no_grad(): for ims, filenames, bad_images in self.dataloader: + if ims is None or len(ims) == 0: + bad_im_count += len(bad_images) + continue arr = self.model(ims.to(self.device)) - feat_arr.extend(arr) + feat_arr.append(arr.cpu().detach().numpy()) + del arr + if self.device.type == 'cuda': + torch.cuda.empty_cache() all_filenames.extend(filenames) - if bad_images: - bad_im_count += 1 + bad_im_count += len(bad_images) if bad_im_count: self.logger.info( f"Found {bad_im_count} bad images, ignoring for encoding generation .." ) - feat_vec = torch.stack(feat_arr).squeeze() - feat_vec = ( - feat_vec.detach().numpy() - if self.device.type == "cpu" - else feat_vec.detach().cpu().numpy() - ) + if not feat_arr: + self.logger.info('No valid images found for encoding generation ..') + self.encoding_map = {} + return self.encoding_map + + feat_vec = np.vstack(feat_arr) valid_image_files = [filename for filename in all_filenames if filename] self.logger.info("End: Image encoding generation") filenames = generate_relative_names(image_dir, valid_image_files) if ( - len(feat_vec.shape) == 1 + feat_vec.shape[0] == 1 ): # can happen when encode_images is called on a directory containing a single image - self.encoding_map = {filenames[0]: feat_vec} + self.encoding_map = {filenames[0]: feat_vec[0]} else: self.encoding_map = {j: feat_vec[i, :] for i, j in enumerate(filenames)} return self.encoding_map diff --git a/imagededup/utils/data_generator.py b/imagededup/utils/data_generator.py index c2abb8a3..2f7bae45 100644 --- a/imagededup/utils/data_generator.py +++ b/imagededup/utils/data_generator.py @@ -46,6 +46,10 @@ def _collate_fn(batch: List[Dict]) -> Tuple[torch.tensor, str, str]: filenames.append(b['filename']) else: bad_images.append(b['filename']) + + if not ims: + return None, filenames, bad_images + return torch.stack(ims), filenames, bad_images diff --git a/imagededup/utils/general_utils.py b/imagededup/utils/general_utils.py index f9825072..2148e403 100644 --- a/imagededup/utils/general_utils.py +++ b/imagededup/utils/general_utils.py @@ -62,12 +62,10 @@ def save_json(results: Dict, filename: str, float_scores: bool = False) -> None: def parallelise(function: Callable, data: List, verbose: bool, num_workers: int) -> List: num_workers = 1 if num_workers < 1 else num_workers # Pool needs to have at least 1 worker. - pool = Pool(processes=num_workers) - results = list( - tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose) - ) - pool.close() - pool.join() + with Pool(processes=num_workers) as pool: + results = list( + tqdm.tqdm(pool.imap(function, data, 100), total=len(data), disable=not verbose) + ) return results diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e36f2b62..30fd931c 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -88,6 +88,25 @@ def test__init_defaults(cnn): assert cnn.model_config.name == MobilenetV3.name +def test__init_custom_batch_size(): + cnn = CNN(batch_size=32) + assert cnn.batch_size == 32 + + cnn = CNN(batch_size=1) + assert cnn.batch_size == 1 + + +def test__init_invalid_batch_size(): + with pytest.raises(ValueError): + CNN(batch_size=0) + + with pytest.raises(ValueError): + CNN(batch_size=-1) + + with pytest.raises(ValueError): + CNN(batch_size='abc') + + def test__init_custom(): cnn = CNN(model_config=CustomModel(model=EfficientNet(), transform=EfficientNet.transform, @@ -983,6 +1002,58 @@ def test_find_duplicates_to_remove_encoding_integration(cnn): ) +# batch_size + + +def test_small_batch_size_produces_same_results(cnn): + cnn_small = CNN(batch_size=2) + encodings_default = cnn.encode_images(TEST_IMAGE_DIR) + encodings_small = cnn_small.encode_images(TEST_IMAGE_DIR) + + assert set(encodings_default.keys()) == set(encodings_small.keys()) + for k in encodings_default: + np.testing.assert_allclose(encodings_default[k], encodings_small[k], atol=1e-5) + + +def test_batch_size_one_produces_same_results(cnn): + cnn_one = CNN(batch_size=1) + encodings_default = cnn.encode_images(TEST_IMAGE_DIR) + encodings_one = cnn_one.encode_images(TEST_IMAGE_DIR) + + assert set(encodings_default.keys()) == set(encodings_one.keys()) + for k in encodings_default: + np.testing.assert_allclose(encodings_default[k], encodings_one[k], atol=1e-5) + + +def test_batch_size_larger_than_dataset(): + cnn_large = CNN(batch_size=128) + encodings = cnn_large.encode_images(TEST_IMAGE_DIR) + assert len(encodings) == 10 + + +def test_small_batch_size_find_duplicates_integration(): + cnn_small = CNN(batch_size=2) + duplicates = cnn_small.find_duplicates( + image_dir=TEST_IMAGE_DIR_MIXED, + min_similarity_threshold=0.9, + scores=False, + ) + assert 'ukbench00120.jpg' in duplicates + assert len(duplicates['ukbench00120.jpg']) > 0 + assert len(duplicates['ukbench09268.jpg']) == 0 + + +def test_all_bad_images_returns_empty_encoding(tmp_path): + bad_file = tmp_path / 'corrupt.jpg' + bad_file.write_bytes(b'not an image') + bad_file2 = tmp_path / 'corrupt2.jpg' + bad_file2.write_bytes(b'also not an image') + + cnn_inst = CNN() + result = cnn_inst.encode_images(tmp_path) + assert result == {} + + def test_scores_saving(cnn): save_file = 'myduplicates.json' cnn.find_duplicates(