Calling torch.nn.Module functions with DataParallel

Is it possible to call func1 with the DataParallel functionality retained?


I have a class A, which defines all my networks, and I am wrapping it with torch.nn.DataParallel. When I call the forward function via a(), it works fine. However, calling other methods of A, such as func1, through the DataParallel wrapper does not work.

Minimal non-working example (just to convey the context better):

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ... define the networks ...

    def forward(self, some_arguments):
        ...  # forward pass

    def func1(self, some_arguments):
        ...  # some other computation

a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling the forward function
outputs = a(inputs)  # works fine
# calling func1
outputs1 = a.func1(inputs)  # does not work (AttributeError: the wrapper has no attribute 'func1')
outputs1 = a.module.func1(inputs)  # works, but without parallelizing the data; I am not sure if this is the right thing to do

Can I call func1 with the DataParallel functionality retained?

No, you cannot directly call func1 while retaining the DataParallel functionality.

When you wrap your model a with torch.nn.DataParallel, the wrapper splits the input batch across the given devices, replicates the model on each of them, runs the forward pass in parallel, and gathers the results back. However, this parallelization applies only when the wrapper itself is called, i.e. only to the forward method.
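Because only forward is scattered across devices, one possible workaround is to dispatch to func1 from inside forward. The sketch below assumes this pattern; the mode argument and the Linear layer are made up for illustration and are not part of the original code:

```python
import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 10)  # hypothetical layer

    def forward(self, x, mode="forward"):
        # Dispatch inside forward so DataParallel scatters the inputs
        # regardless of which computation is requested.
        if mode == "func1":
            return self.func1(x)
        return self.net(x)

    def func1(self, x):
        return torch.relu(self.net(x))

a = torch.nn.DataParallel(A().cuda(), device_ids=[0, 1])
inputs = torch.randn(8, 10).cuda()
outputs = a(inputs)                 # forward pass, run in parallel
outputs1 = a(inputs, mode="func1")  # func1's logic, also run in parallel
```

DataParallel scatters both positional and keyword arguments, so the non-tensor mode flag is simply copied to each replica.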

To call func1 at all, you can reach the original, unwrapped model via a.module and call func1 on it directly. Note, however, that this bypasses DataParallel entirely, so func1 runs on a single device.

Here’s an example:

outputs1 = a.module.func1(inputs)

But keep in mind that this will not utilize the parallel computing power of DataParallel.
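If func1 itself really needs to run across multiple devices, another option is to expose that computation through a forward of its own and wrap it separately. This is only a sketch; the Func1Wrapper adapter below is hypothetical and not part of DataParallel:

```python
class Func1Wrapper(torch.nn.Module):
    """Hypothetical adapter: exposes func1 of the wrapped model as a forward."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model.func1(x)

# a.module is the original, unwrapped model A
func1_parallel = torch.nn.DataParallel(Func1Wrapper(a.module), device_ids=[0, 1])
outputs1 = func1_parallel(inputs)  # func1 now runs through a forward, so its inputs are scattered
```

Since the wrapper shares the same underlying parameters as a.module, no extra copies of the weights are kept on the source device.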