Is it possible to call func1 with the DataParallel functionality retained?
I have a class A that defines all my networks, and I am wrapping an instance of it with torch.nn.DataParallel. When I call the forward function via a(), it works fine. However, when I call other methods of A, such as func1, while still retaining the DataParallel functionality, it does not work.
Minimum Non-Working Example (Just to convey the context better):
class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # blah blah blah

    def forward(self, some_arguments):
        # blah blah blah
        ...

    def func1(self, some_arguments):
        # blah blah blah
        ...
a = A()
a = torch.nn.DataParallel(a, device_ids=[0, 1])
# calling forward function
outputs = a(inputs) # works fine.
# calling func1
outputs1 = a.func1(inputs) # does not work.
outputs1 = a.module.func1(inputs) # runs, but without parallelizing the data; I am not sure if this is the right thing to do
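One workaround I have considered (a sketch, not necessarily the right design) is to dispatch to func1 from inside forward, since DataParallel only scatters inputs and replicates the module around the forward() call. The `mode` keyword and the `Linear` layer below are placeholders I introduce purely for illustration:

```python
import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # placeholder layer for illustration

    def forward(self, x, mode="forward"):
        # DataParallel parallelizes only forward(), so routing other
        # computations through forward() keeps them parallelized.
        if mode == "func1":
            return self.func1(x)
        return self.linear(x)

    def func1(self, x):
        return torch.relu(self.linear(x))

# Wrap with DataParallel only when multiple GPUs are available
a = A()
if torch.cuda.device_count() > 1:
    a = torch.nn.DataParallel(a.cuda(), device_ids=[0, 1])

x = torch.randn(2, 4)
outputs1 = a(x, mode="func1")  # func1 now runs under the DataParallel machinery
```

Is this kind of dispatch the intended way to do it, or is there a cleaner mechanism?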
Can I call func1 with the DataParallel functionality retained?