PyTorch: trying to create a joint dataset with different transforms results in both datasets having same transform

Summary

The core issue stems from a subtle naming conflict and shared mutable state in the user’s PyTorch data pipeline. The variable names train_set and val_set are first assigned to the unmasked dataset subsets, but later in the script the same names are accidentally reused for the masked subsets. This overwrites the original unmasked references, so the JointDataset receives the same masked data for both inputs and produces identical outputs. Additionally, every MyDataset instance holds a reference to the same underlying numpy array in self.data; because the array is shared rather than copied, an in-place masking operation in one instance can corrupt the data seen by the other.

Root Cause

  1. Variable Name Shadowing: In the provided code, train_set and val_set are initially defined as Subset objects of the original train_dataset. Later, the code reassigns these exact same variable names to new Subset objects of the masked dataset:

    # Original assignment
    train_set, val_set = data.random_split(train_dataset, [80, 20]...)
    
    # Overwriting assignment (suspected intent was train_masked_set, val_masked_set)
    train_set = data.Subset(train_masked_dataset, train_indices)
    val_set = data.Subset(train_masked_dataset, val_indices)

    Consequently, JointDataset(train_set, train_masked_set) actually passes the masked dataset as both arguments.

  2. In-Place Mutation in __getitem__: The add_masking function modifies the sample array in place (noisy_sample[i] = -1.0). Since sample = self.data[:,idx] returns a view of the underlying numpy array (not a copy), this operation permanently modifies the original dataset’s data storage.

    sample = self.data[:,idx] # This is a view, not a copy
    if self.transform:
        sample = self.transform(sample) # Transform modifies the view
  3. Shared Mutable State: Each MyDataset instance stores a reference to the caller’s numpy array in self.data rather than a private copy. When two instances are constructed from the same array (as intended for the unmasked and masked variants), the in-place masking operation performed through one instance corrupts the data served by the other.
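The view-versus-copy behavior at the heart of points 2 and 3 can be reproduced in isolation. This sketch uses a toy array (not the user's data) to show how mutating a slice silently corrupts the source, and how a defensive copy breaks the link:

    import numpy as np

    data = np.zeros((3, 4))      # stands in for the dataset's backing array
    sample = data[:, 0]          # basic slicing returns a view, not a copy
    sample[0] = -1.0             # "masking" the sample...
    print(data[0, 0])            # ...has silently modified the source: -1.0

    safe = data[:, 1].copy()     # defensive copy breaks the link
    safe[0] = -1.0
    print(data[0, 1])            # source untouched: 0.0
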

Why This Happens in Real Systems

This type of error is common in data engineering pipelines and research codebases for several reasons:

  • Exploratory Programming: When iterating quickly on data augmentation strategies, engineers often copy-paste code blocks and forget to update variable names (e.g., changing train_dataset to train_masked_dataset but keeping the variable assignment train_set).
  • Implicit Dependencies: Relying on global mutable state (like the shared numpy array in MyDataset) creates hidden dependencies between dataset instances. A change in one instance (masking) unexpectedly affects another.
  • Lack of Defensive Copying: Failing to copy data in __getitem__ violates the principle of encapsulation. Dataset classes should return independent copies (or immutable views) of data to prevent side effects between DataLoader workers or different pipeline stages.
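The hidden-dependency failure mode described above can be shown with two minimal dataset-like objects sharing one array (the class and names here are illustrative, not from the original code):

    import numpy as np

    backing = np.arange(6, dtype=float).reshape(2, 3)

    class SharedView:
        """Minimal stand-in for a Dataset holding a reference to shared data."""
        def __init__(self, array):
            self.data = array         # shared reference, no copy
        def get(self, idx):
            return self.data[:, idx]  # returns a view

    clean = SharedView(backing)
    masked = SharedView(backing)      # both instances reference the same array

    masked.get(0)[:] = -1.0           # in-place "masking" via one instance...
    print(clean.get(0))               # ...is visible through the other: [-1. -1.]
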

Real-World Impact

  • Silent Data Corruption: The model trains on identical unmasked and masked inputs. The loss function receives two identical tensors, leading to mathematically valid but logically incorrect training. The model never learns to reconstruct or process masked data correctly.
  • Unreliable Validation: Validation metrics become meaningless because the ground truth (unmasked) and the input (masked) are identical, hiding the model’s inability to handle missing data.
  • Debugging Difficulty: The bug does not cause immediate crashes. It manifests as poor model performance or convergence to trivial solutions (e.g., the identity function), requiring extensive debugging to locate the root cause in the data pipeline.

Example Code

To demonstrate the fix, we must ensure:

  1. Distinct variable names for unmasked and masked subsets.
  2. Defensive copying to prevent in-place mutation of the source data.

    import torch
    import torch.utils.data as data
    import numpy as np
    from torchvision import transforms
    from typing import Callable

    def add_masking(sample):
        # Create a copy to avoid modifying the original data array
        noisy_sample = sample.copy()
        prob = 0.25
        for i in range(sample.shape[0]):
            if np.random.uniform(0.0, 1.0) < prob:
                noisy_sample[i] = -1.0
        return noisy_sample

    class MyDataset(data.Dataset):
        def __init__(self, array: np.ndarray, transform: Callable = None):
            self.data = array
            self.transform = transform

        def __len__(self):
            return self.data.shape[1]

        def __getitem__(self, idx):
            # Return a copy to ensure modifications don't affect the source
            sample = self.data[:, idx].copy()
            if self.transform:
                sample = self.transform(sample)
            return sample

    # Generate dataset
    num_examples = 100
    num_features = 10
    X = np.random.rand(num_features, num_examples)

    # Create base (unmasked) dataset
    base_dataset = MyDataset(array=X, transform=None)

    # Split once and reuse the indices for both variants
    train_set, val_set = data.random_split(
        base_dataset, [80, 20], generator=torch.Generator().manual_seed(42))
    train_indices = train_set.indices
    val_indices = val_set.indices

    # Create masked dataset with the masking transform
    masked_dataset = MyDataset(array=X, transform=transforms.Lambda(add_masking))

    # Create subsets for masked data using UNIQUE variable names
    train_masked_set = data.Subset(masked_dataset, train_indices)
    val_masked_set = data.Subset(masked_dataset, val_indices)

    class JointDataset(data.Dataset):
        def __init__(self, dataset1, dataset2):
            self.dataset1 = dataset1
            self.dataset2 = dataset2
            assert len(self.dataset1) == len(self.dataset2)

        def __len__(self):
            return len(self.dataset1)

        def __getitem__(self, index):
            return self.dataset1[index], self.dataset2[index]

    # Combine unmasked and masked data correctly
    train_set_combined = JointDataset(train_set, train_masked_set)
    val_set_combined = JointDataset(val_set, val_masked_set)

    def numpy_collate(batch):
        if isinstance(batch[0], np.ndarray):
            return np.stack(batch)
        elif isinstance(batch[0], (tuple, list)):
            transposed = zip(*batch)
            return [numpy_collate(samples) for samples in transposed]
        else:
            return np.array(batch)

    train_loader = data.DataLoader(train_set_combined, batch_size=16, shuffle=True,
                                   drop_last=True, pin_memory=True, num_workers=0,
                                   collate_fn=numpy_collate)

    data_iter = iter(train_loader)
    batch_unmasked, batch_masked = next(data_iter)

    # Verification: the masked batch should differ from the unmasked batch
    # wherever masking fired (entries set to -1.0)
    print(batch_unmasked[0:2, :].T)
    print(batch_masked[0:2, :].T)

How Senior Engineers Fix It

  1. Enforce Immutability: The primary fix is ensuring __getitem__ returns a copy of the data, not a view. This isolates the dataset instance from side effects of transforms.

    Fix: Use .copy() in __getitem__, or ensure transforms create new arrays/tensors instead of writing in place.

  2. Sanitize Variable Scopes: Avoid reusing variable names for different purposes. Senior engineers use distinct, descriptive names (e.g., train_subset_raw vs train_subset_masked) to prevent shadowing errors.
  3. Add Validation Hooks: Before passing datasets to the DataLoader, implement a simple check to verify data integrity. For example, iterate over the first few samples of the JointDataset and assert that data1 is not equal to data2 (assuming a high probability of masking).
  4. Isolate State: Ensure that dataset classes do not rely on shared mutable state. Each dataset instance should own its data or hold immutable references.
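The validation-hook idea can be implemented as a short pre-training check. This is a sketch assuming a joint dataset that yields (unmasked, masked) numpy pairs, as in the example above; the function name is illustrative:

    import numpy as np

    def check_joint_dataset(joint_dataset, n_samples=5):
        """Assert that the two halves of a joint dataset are not identical.

        With a 25% per-feature masking probability, the chance that several
        samples all come out unmasked is tiny, so identical pairs almost
        certainly indicate a wiring bug (e.g. the same dataset passed twice).
        """
        any_different = False
        for i in range(min(n_samples, len(joint_dataset))):
            clean, masked = joint_dataset[i]
            if not np.array_equal(clean, masked):
                any_different = True
        assert any_different, (
            "unmasked and masked samples are identical -- check dataset wiring")

Running `check_joint_dataset(train_set_combined)` once before constructing the DataLoader would have caught the original bug immediately instead of letting it surface as poor model performance.
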

Why Juniors Miss It

  • Focus on Logic, Not State: Junior developers often focus on the algorithmic logic of the transform (the add_masking function) but overlook how data is stored and referenced in memory (views vs. copies).
  • Visual Inspection Blindness: When debugging, juniors often follow the logic flow but miss the variable names. They might see train_set = ... and assume it still refers to the original set, not noticing it has been reassigned.
  • Underestimating In-Place Operations: The fact that sample = self.data[:, idx] creates a mutable view rather than a safe copy is a common trap. Beginners often assume assigning to a variable creates a new object, not realizing that numpy basic slicing returns views into the same underlying buffer.
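When in doubt, numpy can report directly whether two arrays share storage; a quick check like this (variable names are illustrative) would have exposed the trap:

    import numpy as np

    data = np.random.rand(10, 100)

    view = data[:, 0]                     # basic slicing: a view
    col_copy = data[:, 0].copy()          # explicit copy

    print(np.shares_memory(data, view))       # True  -> mutations leak back
    print(np.shares_memory(data, col_copy))   # False -> safe to mutate
    print(view.base is data)                  # True: a view's .base is its source
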