giagrad.nn.Module.apply

Module.apply(fn: Callable)

Applies fn recursively to every submodule as well as self.

Typical use includes initializing the parameters of a model.
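The recursion described above can be sketched in plain Python. The `Tensor`, `Module`, `Linear` and `Model` classes below are hypothetical stand-ins, not giagrad's classes. Two details are assumptions: submodules and tensors are found by scanning instance attributes, and children are visited before `self` (post-order, the order PyTorch's `Module.apply` uses). giagrad may discover or order them differently.

```python
from __future__ import annotations

from typing import Callable, Iterator, Union

import numpy as np


class Tensor:
    """Hypothetical stand-in for a tensor: wraps a NumPy array in .data."""

    def __init__(self, data) -> None:
        self.data = np.asarray(data, dtype=float)


class Module:
    """Hypothetical minimal module, not giagrad's actual implementation."""

    def children(self) -> Iterator[Union[Module, Tensor]]:
        # Assumption: tensors and submodules are stored as plain attributes.
        for value in vars(self).values():
            if isinstance(value, (Module, Tensor)):
                yield value

    def apply(self, fn: Callable[[Union[Module, Tensor]], None]) -> None:
        for child in self.children():
            if isinstance(child, Module):
                child.apply(fn)  # recurse: visits the child's subtree, then the child
            else:
                fn(child)        # leaf tensor: call fn on it directly
        fn(self)                 # visit self last (post-order)


class Linear(Module):
    def __init__(self, n_in: int, n_out: int) -> None:
        self.w = Tensor(np.zeros((n_out, n_in)))
        self.b = Tensor(np.zeros(n_out))


class Model(Module):
    def __init__(self) -> None:
        self.l1 = Linear(4, 3)
        self.l2 = Linear(3, 2)


model = Model()

# Record the visit order: every tensor and every module is passed to fn once.
visited: list[str] = []
model.apply(lambda obj: visited.append(type(obj).__name__))
print(visited)
# ['Tensor', 'Tensor', 'Linear', 'Tensor', 'Tensor', 'Linear', 'Model']


def init_weights(obj: Union[Module, Tensor]) -> None:
    if isinstance(obj, Linear):
        obj.w.data[...] = 1.0  # in-place write, so the change is visible outside fn


model.apply(init_weights)
print(bool(np.all(model.l1.w.data == 1)), bool(np.all(model.l2.w.data == 1)))
# True True
```

Because `fn` sees both tensors and modules, filtering with `isinstance` inside `fn` is what restricts it to the objects it should touch.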

Parameters:

fn (callable, Module -> None) – Function applied to each submodule, which may be either a Tensor or a Module. fn must modify its argument in place.
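Why fn has to work in place can be seen in a small, library-agnostic sketch. `Param`, `rebind` and `mutate` are hypothetical names, not part of giagrad. Rebinding the argument inside the function only changes a local variable, whereas writing into its array changes the very object the caller passed in.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for a tensor holding its values in .data."""

    def __init__(self, shape) -> None:
        self.data = np.zeros(shape)


def rebind(p: Param) -> None:
    p = Param(p.data.shape)  # rebinds the local name only
    p.data[...] = 1.0        # modifies the new, throwaway object


def mutate(p: Param) -> None:
    p.data[...] = 1.0        # writes into the caller's array


w = Param((2, 3))
rebind(w)
after_rebind = float(w.data.sum())  # 0.0: the caller's w is untouched
mutate(w)
after_mutate = float(w.data.sum())  # 6.0: all 2 * 3 entries are now 1
print(after_rebind, after_mutate)   # 0.0 6.0
```

Any fn passed to apply behaves like `mutate` only if it updates existing data; returning a new object is ignored.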

Examples

Using the class Model defined in the Module example.

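>>> import numpy as np
>>> import giagrad.nn as nn
>>> def init_weights(module):
...     # fn also receives Tensors, so act only on Linear layers
...     if isinstance(module, nn.Linear):
...         module.w.ones()  # in-place: set every weight to 1
>>> m.apply(init_weights)
>>> bool(np.all(m.l1.w.data == 1) and np.all(m.l2.w.data == 1))
True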