diff --git a/miceforest/imputed_data.py b/miceforest/imputed_data.py index 4e59341..dcae1fc 100644 --- a/miceforest/imputed_data.py +++ b/miceforest/imputed_data.py @@ -31,15 +31,10 @@ def __init__( self.working_data.index, RangeIndex ), "Please reset the index on the dataframe" - column_names = [] - pd_dtypes_orig = {} - for col, series in self.working_data.items(): - assert isinstance(col, str), "column names must be strings" - assert ( - series.dtype.name != "object" - ), "convert object dtypes to something else" - column_names.append(col) - pd_dtypes_orig[col] = series.dtype.name + column_names = self.working_data.columns + assert np.all( + [isinstance(col, str) for col in column_names] + ), "Column names must be strings" self.column_names = column_names pd_dtypes_orig = self.working_data.dtypes @@ -91,6 +86,18 @@ def __init__( col for col in self.modeled_variables if col in self.vars_with_any_missing ] + # This should be all variables in the schema, not all variables in the dataset. + self.all_var_in_schema = set( + self.modeled_variables + + [y for x in self.variable_schema.values() for y in x] + ) + + for col in self.all_var_in_schema: + assert pd_dtypes_orig[col].name != "object", ( + "Cannot model an object column, please convert to int or categorical, or " + "specify a variable_schema that does not use the object column." + ) + if random_seed_array is not None: assert isinstance(random_seed_array, np.ndarray) assert ( diff --git a/pyproject.toml b/pyproject.toml index 19ae41f..2902c39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "miceforest" license = "MIT" -version = "6.0.4" +version = "6.0.5" description = "Multiple Imputation by Chained Equations with LightGBM" authors = [{name="Sam Von Wilson"}] readme = "README.md" diff --git a/tests/test_ImputationKernel.py b/tests/test_ImputationKernel.py index 609fbe1..7a21ea8 100644 --- a/tests/test_ImputationKernel.py +++ b/tests/test_ImputationKernel.py @@ -368,3 +368,36 @@ def test_complex(): mean_match_strategy=mixed_mms, save_all_iterations_data=True, ) + + +def test_object_column(): + + # Customize everything. + vs = { + "sl": ["ws", "pl", "pw", "sp", "bi"], + "ws": ["sl"], + "pl": ["sp", "bi"], + # 'sp': ['sl', 'ws', 'pl', 'pw', 'bc'], # Purposely don't train a variable that does have missing values + "pw": ["sl", "ws", "pl", "sp", "bi"], + "bi": ["ws", "pl", "sp"], + "ui8": ["sp", "ws"], + } + mmc = {"sl": 4, "ws": 0, "bi": 5} + ds = {"sl": int(iris_amp.shape[0] / 2), "ws": 50} + + iris_amp["obj_col"] = iris_amp["sl"].astype("object") + + imputed_var_names = list(vs) + non_imputed_var_names = [c for c in iris_amp if c not in imputed_var_names] + kernel = mf.ImputationKernel( + data=iris_amp, + num_datasets=2, + variable_schema=vs, + mean_match_candidates=mmc, + data_subset=ds, + mean_match_strategy="normal", + save_all_iterations_data=True, + ) + + assert "obj_col" not in kernel.variable_schema + assert "obj_col" not in kernel.all_var_in_schema