Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions miceforest/imputed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,10 @@ def __init__(
self.working_data.index, RangeIndex
), "Please reset the index on the dataframe"

column_names = []
pd_dtypes_orig = {}
for col, series in self.working_data.items():
assert isinstance(col, str), "column names must be strings"
assert (
series.dtype.name != "object"
), "convert object dtypes to something else"
column_names.append(col)
pd_dtypes_orig[col] = series.dtype.name
column_names = self.working_data.columns
assert np.all(
[isinstance(col, str) for col in column_names]
), "Column names must be strings"

self.column_names = column_names
pd_dtypes_orig = self.working_data.dtypes
Expand Down Expand Up @@ -91,6 +86,18 @@ def __init__(
col for col in self.modeled_variables if col in self.vars_with_any_missing
]

# This should be all variables in the schema, not all variables in the dataset.
self.all_var_in_schema = set(
self.modeled_variables
+ [y for x in self.variable_schema.values() for y in x]
)

for col in self.all_var_in_schema:
assert pd_dtypes_orig[col].name != "object", (
"Cannot model an object column, please convert to int or categorical, or "
"specify a variable_schema that does not use the object column."
)

if random_seed_array is not None:
assert isinstance(random_seed_array, np.ndarray)
assert (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "miceforest"
license = "MIT"
version = "6.0.4"
version = "6.0.5"
description = "Multiple Imputation by Chained Equations with LightGBM"
authors = [{name="Sam Von Wilson"}]
readme = "README.md"
Expand Down
33 changes: 33 additions & 0 deletions tests/test_ImputationKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,36 @@ def test_complex():
mean_match_strategy=mixed_mms,
save_all_iterations_data=True,
)


def test_object_column():
    """Check that an object column outside the schema does not break the kernel.

    An ``object``-dtype column is added to the data but is never referenced by
    the custom variable schema. The kernel is then built, and the column must
    appear in neither ``kernel.variable_schema`` nor
    ``kernel.all_var_in_schema``.
    """
    # Work on a copy so the extra column does not leak into the module-level
    # ``iris_amp`` that other tests in this file rely on.
    # NOTE(review): assumes ``iris_amp`` is a pandas DataFrame — confirm.
    data = iris_amp.copy()
    data["obj_col"] = data["sl"].astype("object")

    # Customize everything.
    # "sp" is purposely not given a model of its own, even though it does
    # have missing values.
    vs = {
        "sl": ["ws", "pl", "pw", "sp", "bi"],
        "ws": ["sl"],
        "pl": ["sp", "bi"],
        "pw": ["sl", "ws", "pl", "sp", "bi"],
        "bi": ["ws", "pl", "sp"],
        "ui8": ["sp", "ws"],
    }
    mmc = {"sl": 4, "ws": 0, "bi": 5}
    ds = {"sl": int(data.shape[0] / 2), "ws": 50}

    kernel = mf.ImputationKernel(
        data=data,
        num_datasets=2,
        variable_schema=vs,
        mean_match_candidates=mmc,
        data_subset=ds,
        mean_match_strategy="normal",
        save_all_iterations_data=True,
    )

    assert "obj_col" not in kernel.variable_schema
    assert "obj_col" not in kernel.all_var_in_schema