gromo.modules.linear_growing_module.LinearMergeGrowingModule

class gromo.modules.linear_growing_module.LinearMergeGrowingModule(post_merge_function: Module = Identity(), previous_modules: list[GrowingModule | MergeGrowingModule] | None = None, next_modules: list[GrowingModule | MergeGrowingModule] | None = None, allow_growing: bool = False, in_features: int | None = None, device: device | None = None, name: str | None = None)

Module to connect multiple linear modules with a merge operation. This module does not perform the merge operation itself; the user is responsible for performing it.

Parameters:
  • post_merge_function (torch.nn.Module, optional) – activation function after the merge, by default torch.nn.Identity()

  • previous_modules (list[GrowingModule | MergeGrowingModule] | None, optional) – list of preceding modules, by default None

  • next_modules (list[GrowingModule | MergeGrowingModule] | None, optional) – list of succeeding modules, by default None

  • allow_growing (bool, optional) – allow growth of the module, by default False

  • in_features (int | None, optional) – input features, by default None

  • device (torch.device | None, optional) – default device, by default None

  • name (str | None, optional) – name of the module, by default None
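Because the merge itself is left to the user, it may help to see what that could look like. The following is a minimal sketch in plain PyTorch, not the library's own code: it assumes the merge is an element-wise sum of the branch outputs, and the names `branch_a`, `branch_b`, `xa`, and `xb` are hypothetical stand-ins for the previous modules and their inputs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Two hypothetical previous modules that both produce 3 features,
# matching the in_features that a merge module would expect.
branch_a = nn.Linear(4, 3)
branch_b = nn.Linear(6, 3)

# Stand-in for the post_merge_function parameter (default: Identity).
post_merge_function = nn.Identity()

xa = torch.randn(2, 4)
xb = torch.randn(2, 6)

# The merge is performed by the user; a sum is assumed here.
merged = branch_a(xa) + branch_b(xb)
out = post_merge_function(merged)
```

With the default identity function, `out` equals `merged`; any other `torch.nn.Module` activation could be substituted in the same position.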

compute_previous_m_update() → tuple[Tensor, int]

Compute the update of the tensor M for the input of all previous modules. With B the full activity tensor, M = dLoss/dA^T B.

Returns:

  • torch.Tensor – update of the tensor M

  • int – number of samples used to compute the update
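As an illustration of the formula for M, here is a minimal sketch in plain PyTorch. It rests on several assumptions: the formula is read as (dLoss/dA)^T B, the gradient is obtained with `torch.autograd.grad`, and the weight `W`, the toy loss, and all sizes are hypothetical. The library may store M with a different layout (for example transposed).

```python
import torch

torch.manual_seed(0)
n, c_in, c_out = 5, 3, 2

B = torch.randn(n, c_in)                       # full activity tensor (assumed given)
W = torch.randn(c_out, c_in, requires_grad=True)

A = B @ W.T                                    # pre-activation A, shape (n, c_out)
loss = 0.5 * (A ** 2).sum()                    # toy loss, so that dLoss/dA == A

# Gradient of the loss with respect to the pre-activation A.
(dA,) = torch.autograd.grad(loss, A)

# Assumed reading of the formula: M = (dLoss/dA)^T B, shape (c_out, c_in).
M_update = dA.T @ B
n_samples = B.shape[0]                         # samples behind this update
```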

compute_previous_s_update() → tuple[Tensor, int]

Compute the update of the tensor S for the input of all previous modules. With B the full activity tensor, S = B^T B.

Returns:

  • torch.Tensor – update of the tensor S

  • int – number of samples used to compute the update
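Returning the sample count alongside the update suggests that updates from several batches can be accumulated and then normalized. The sketch below shows one way to do that in plain PyTorch; the accumulation scheme and the names `S_sum`, `n_total`, and `batches` are assumptions made for illustration, not the library's implementation.

```python
import torch

torch.manual_seed(0)
c = 3

# Two hypothetical batches of the full activity tensor B.
batches = [torch.randn(5, c), torch.randn(7, c)]

S_sum = torch.zeros(c, c)
n_total = 0
for B in batches:
    S_update = B.T @ B        # S = B^T B for this batch, shape (c, c)
    n_samples = B.shape[0]    # number of samples used for this update
    S_sum += S_update
    n_total += n_samples

# Average over all samples seen so far.
S_mean = S_sum / n_total
```

Summing the per-batch updates gives the same matrix as computing B^T B on the concatenated batches, which is why the sample count is all that is needed to normalize.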

compute_s_update() → Tensor

Compute the update of the tensor S. With the input tensor X, the update is U^{j k} = X^{i j} X^{i k}, with summation over the repeated sample index i.

Returns:

update of the tensor S

Return type:

torch.Tensor
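The index expression U^{j k} = X^{i j} X^{i k} maps directly onto `torch.einsum`. The following sketch, with a hypothetical input `X`, shows that this contraction is a sum of per-sample outer products and equals X^T X; it illustrates the formula only and is not taken from the library source.

```python
import torch

torch.manual_seed(0)
X = torch.randn(6, 4)   # hypothetical input: 6 samples, 4 features

# U[j, k] = sum_i X[i, j] * X[i, k]
U = torch.einsum("ij,ik->jk", X, X)

# Same quantity written as an explicit sum of outer products.
U_loop = torch.zeros(4, 4)
for x in X:
    U_loop += torch.outer(x, x)
```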

construct_full_activity() → Tensor

Construct the full activity tensor B from the input of all previous modules. B = (B_1, B_2, …, B_k) in (n, C_1 + C_2 + … + C_k), with C_i the number of features of the i-th module. If the bias is used, B_i = (X_i, 1) in (n, C_i' + 1).

Returns:

full activity tensor

Return type:

torch.Tensor
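The construction of B can be sketched as a concatenation along the feature dimension, with a column of ones appended to each block whose module uses a bias. The helper `full_activity` and its arguments are hypothetical names introduced for this example; it is a sketch of the formula above, not the library's method.

```python
import torch

def full_activity(inputs, use_bias):
    """Concatenate per-module inputs, appending a ones column where a bias is used."""
    blocks = []
    for x, bias in zip(inputs, use_bias):
        if bias:
            ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
            x = torch.cat([x, ones], dim=1)   # B_i = (X_i, 1)
        blocks.append(x)
    return torch.cat(blocks, dim=1)           # B = (B_1, ..., B_k)

n = 4
x1 = torch.randn(n, 3)   # input of module 1: 3 features, bias used
x2 = torch.randn(n, 2)   # input of module 2: 2 features, no bias
B = full_activity([x1, x2], [True, False])
```

Here B has shape (4, 6): 3 features plus the bias column from the first module, followed by 2 features from the second.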

property input_volume: int

Expected input volume. For linear merge layers, this reduces to the number of input features.

Returns:

input volume

Return type:

int

property out_features: int

Output features. For linear merge layers, this reduces to the number of input features.

Returns:

output features

Return type:

int

property output_volume: int

Expected output volume. For linear merge layers, this reduces to the number of input features.

Returns:

output volume

Return type:

int

set_next_modules(next_modules: list[MergeGrowingModule | GrowingModule]) → None

Set the next modules of the current module.

Parameters:

next_modules (list[MergeGrowingModule | GrowingModule]) – list of next modules

set_previous_modules(previous_modules: list[MergeGrowingModule | GrowingModule]) → None

Set the previous modules of the current module.

Parameters:

previous_modules (list[MergeGrowingModule | GrowingModule]) – list of previous modules

Raises:
  • TypeError – if a previous module is not of type LinearGrowingModule or MergeGrowingModule

  • ValueError – if the input features do not match the output volume of the previous modules
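The two error conditions can be illustrated with a standalone check. This is a hypothetical sketch: `check_previous_modules` is not part of the library, and `torch.nn.Linear` stands in for the accepted module types, with its `out_features` playing the role of the output volume.

```python
from torch import nn

def check_previous_modules(previous_modules, in_features):
    """Hypothetical validation mirroring the documented TypeError / ValueError."""
    for module in previous_modules:
        if not isinstance(module, nn.Linear):
            raise TypeError(
                f"Unsupported previous module type: {type(module).__name__}"
            )
        if module.out_features != in_features:
            raise ValueError(
                f"Expected {in_features} features, got {module.out_features}"
            )

# Both branches output 3 features, so the check passes silently.
check_previous_modules([nn.Linear(4, 3), nn.Linear(6, 3)], in_features=3)

# A branch with 5 output features does not match in_features=3.
try:
    check_previous_modules([nn.Linear(4, 5)], in_features=3)
except ValueError as err:
    message = str(err)
```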