gromo.modules.linear_growing_module.LinearMergeGrowingModule
- class gromo.modules.linear_growing_module.LinearMergeGrowingModule(post_merge_function: Module = Identity(), previous_modules: list[GrowingModule | MergeGrowingModule] | None = None, next_modules: list[GrowingModule | MergeGrowingModule] | None = None, allow_growing: bool = False, in_features: int | None = None, device: device | None = None, name: str | None = None)
Module to connect multiple linear modules with a merge operation. This module does not perform the merge operation itself; the user is responsible for it.
- Parameters:
post_merge_function (torch.nn.Module, optional) – activation function after the merge, by default torch.nn.Identity()
previous_modules (list[GrowingModule | MergeGrowingModule] | None, optional) – list of preceding modules, by default None
next_modules (list[GrowingModule | MergeGrowingModule] | None, optional) – list of succeeding modules, by default None
allow_growing (bool, optional) – allow growth of the module, by default False
in_features (int | None, optional) – input features, by default None
device (torch.device | None, optional) – default device, by default None
name (str | None, optional) – name of the module, by default None
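Because the merge itself is left to the caller, the surrounding data flow can be pictured with plain PyTorch layers. The sketch below is not gromo code. It assumes the merge is an elementwise sum of the previous modules' outputs, uses `torch.nn.Linear` as a stand-in for the previous modules, and `merge_outputs` is a hypothetical helper.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins for two previous modules; both must emit `in_features` values.
in_features = 4
branch_a = nn.Linear(3, in_features)
branch_b = nn.Linear(5, in_features)
post_merge_function = nn.ReLU()  # plays the role of `post_merge_function`


def merge_outputs(outputs: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical user-side merge (assumed): elementwise sum of the outputs."""
    return torch.stack(outputs, dim=0).sum(dim=0)


x_a = torch.randn(8, 3)
x_b = torch.randn(8, 5)
merged = merge_outputs([branch_a(x_a), branch_b(x_b)])  # shape (8, 4)
activity = post_merge_function(merged)                  # shape (8, 4)
```

Under this assumption every previous module has to produce the same number of output features, which is consistent with the `ValueError` documented for `set_previous_modules`.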
- compute_previous_m_update() -> tuple[Tensor, int]
Compute the update of the tensor M for the input of all previous modules. With B the full activity tensor, M = dLoss/dA^T B.
- Returns:
torch.Tensor – update of the tensor M
int – number of samples used to compute the update
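The M update can be illustrated with plain tensors. This is a sketch under stated assumptions, not the library's implementation: the formula is read literally as (dLoss/dA)^T B, so the orientation of the result may differ from gromo's internal layout, and the names `full_activity`, `pre_activity`, and the toy loss are illustrative only.

```python
import torch

torch.manual_seed(0)

n, total_in, out_features = 6, 5, 3
full_activity = torch.randn(n, total_in)                          # B, shape (n, total_in)
pre_activity = torch.randn(n, out_features, requires_grad=True)   # A

# Any scalar loss works; a sum of squares keeps the gradient easy to check.
loss = (pre_activity ** 2).sum()
(grad_a,) = torch.autograd.grad(loss, pre_activity)  # dLoss/dA, shape (n, out)

# Assumed literal reading of the formula: M = (dLoss/dA)^T B.
m_update = grad_a.T @ full_activity                  # shape (out, total_in)
n_samples = full_activity.shape[0]                   # second returned value
```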
- compute_previous_s_update() -> tuple[Tensor, int]
Compute the update of the tensor S for the input of all previous modules. With B the full activity tensor, S = B^T B.
- Returns:
torch.Tensor – update of the tensor S
int – number of samples used to compute the update
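As a rough illustration of S = B^T B, the sketch below builds a full activity tensor from two hypothetical per-module blocks (`b_1`, `b_2`) and forms the product with plain PyTorch. It is not gromo's code; it only shows the shapes and the block structure that follow from the formula.

```python
import torch

torch.manual_seed(0)

n = 10
b_1 = torch.randn(n, 3)                        # activity block of a first module
b_2 = torch.randn(n, 4)                        # activity block of a second module
full_activity = torch.cat([b_1, b_2], dim=1)   # B, shape (n, 7)

s_update = full_activity.T @ full_activity     # S = B^T B, shape (7, 7)
n_samples = full_activity.shape[0]             # second returned value

# The top-left block is the second-moment matrix of the first module alone.
top_left = s_update[:3, :3]
```

The off-diagonal blocks hold the cross terms B_1^T B_2 between modules, which is why S is computed on the concatenated tensor rather than per module.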
- compute_s_update() -> Tensor
Compute the update of the tensor S. With the input tensor X, the update is U^{j k} = X^{i j} X^{i k}, summed over the repeated sample index i.
- Returns:
update of the tensor S
- Return type:
torch.Tensor
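The index formula maps directly onto `torch.einsum`. The following is a minimal sketch with a random stand-in for X; it illustrates the contraction and is not taken from the gromo source.

```python
import torch

torch.manual_seed(0)

x = torch.randn(8, 4)  # X, indexed as X^{i j}: i over samples, j over features

# U^{j k} = X^{i j} X^{i k}, with the repeated index i summed over.
s_update = torch.einsum("ij,ik->jk", x, x)  # shape (4, 4)
```

For a 2-D input this is the same matrix as `x.T @ x`; the einsum form simply mirrors the notation used in the description.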
- construct_full_activity() -> Tensor
Construct the full activity tensor B from the input of all previous modules. B = (B_1, B_2, …, B_k) has shape (n, C_1 + C_2 + … + C_k), where C_i is the number of features of the i-th module. If the bias is used, B_i = (X_i, 1) has shape (n, C_i' + 1).
- Returns:
full activity tensor
- Return type:
torch.Tensor
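The construction described above can be sketched with `torch.cat`. This is an illustration under assumptions, not the library's implementation: `append_bias_column` is a hypothetical helper, and the two inputs stand in for the inputs of two previous modules, one with a bias and one without.

```python
import torch


def append_bias_column(x: torch.Tensor) -> torch.Tensor:
    """Return (X, 1): X with a column of ones appended (hypothetical helper)."""
    ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
    return torch.cat([x, ones], dim=1)


n = 5
x_1 = torch.zeros(n, 3)   # input of a previous module that uses a bias
x_2 = torch.zeros(n, 2)   # input of a previous module without a bias

b_1 = append_bias_column(x_1)                  # shape (n, 3 + 1)
b_2 = x_2                                      # shape (n, 2)
full_activity = torch.cat([b_1, b_2], dim=1)   # B, shape (n, 4 + 2)
```

The column of ones lets the bias be treated as one more input feature, so the S and M statistics cover weights and bias in a single tensor.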
- property input_volume: int
Expected input volume. For linear merge layers, this reduces to the number of input features.
- Returns:
input volume
- Return type:
int
- property out_features: int
Output features. For linear merge layers, this reduces to the number of input features.
- Returns:
output features
- Return type:
int
- property output_volume: int
Expected output volume. For linear merge layers, this reduces to the number of input features.
- Returns:
output volume
- Return type:
int
- set_next_modules(next_modules: list[MergeGrowingModule | GrowingModule]) -> None
Set the next modules of the current module.
- Parameters:
next_modules (list[MergeGrowingModule | GrowingModule]) – list of next modules
- set_previous_modules(previous_modules: list[MergeGrowingModule | GrowingModule]) -> None
Set the previous modules of the current module.
- Parameters:
previous_modules (list[MergeGrowingModule | GrowingModule]) – list of previous modules
- Raises:
TypeError – if the previous module is not of type LinearGrowingModule or MergeGrowingModule
ValueError – if the input features do not match the output volume of the previous modules