Sparse Models

class weavenet.sparse.model.TrainableMatchingModuleSp(net_sparse, pre_interactor_sparse=CrossConcat(), mask_selector=MaskSelectorByLinearInferenceOr(), output_channels=1, net_dense=None, pre_interactor_dense=CrossConcat(), stream_aggregator=DualSoftmaxSqrtSp())

A variant of TrainableMatchingModule that handles sparse bipartite graphs.

Parameters
  • net – a sparse GNN net that estimates the matching.

  • mask_estimator – an algorithm that estimates a mask from the input of forward() or from the output of pre_net.

  • output_channels (int) – see TrainableMatchingModule.

  • pre_interactor – see TrainableMatchingModule.

  • pre_net – a GNN net applied before mask_estimator.

  • stream_aggregator (Optional[Callable[[Tensor, Tensor, Tensor, Optional[Tensor]], Tuple[Tensor, Tensor, Tensor]]]) – the aggregator that merges the estimations from all the streams (default: DualSoftmaxSqrtSp).

  • net_sparse (Module) –

  • pre_interactor_sparse (Optional[CrossConcat]) –

  • mask_selector (Module) –

  • net_dense (Optional[Module]) –

  • pre_interactor_dense (Optional[CrossConcat]) –
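The default stream_aggregator, DualSoftmaxSqrtSp, is not described in detail here. A minimal sketch of one plausible reading follows, assuming it applies the usual dual-softmax idea to an edge list: the a-to-b stream is normalized over the edges of each side-a vertex, the b-to-a stream over the edges of each side-b vertex, and the two are merged by the square root of their product. The names grouped_softmax and dual_softmax_sqrt_sparse are hypothetical, and unlike the real aggregator (which returns a tuple of three tensors) the sketch returns only the merged per-edge scores.

```python
import math
from collections import defaultdict
from typing import List, Sequence


def grouped_softmax(logits: Sequence[float], groups: Sequence[int]) -> List[float]:
    """Softmax computed separately over the entries that share a group id."""
    # Per-group maximum, subtracted for numerical stability.
    group_max = defaultdict(lambda: -math.inf)
    for value, group in zip(logits, groups):
        group_max[group] = max(group_max[group], value)
    exps = [math.exp(value - group_max[group]) for value, group in zip(logits, groups)]
    # Per-group normalizer.
    group_sum = defaultdict(float)
    for e, group in zip(exps, groups):
        group_sum[group] += e
    return [e / group_sum[group] for e, group in zip(exps, groups)]


def dual_softmax_sqrt_sparse(
    logits_ab: Sequence[float],
    logits_ba: Sequence[float],
    src: Sequence[int],
    dst: Sequence[int],
) -> List[float]:
    """Hypothetical sketch, not the library's DualSoftmaxSqrtSp.

    logits_ab: per-edge scores from the a->b stream
    logits_ba: per-edge scores from the b->a stream
    src, dst:  side-a and side-b vertex id of each edge
    """
    p_ab = grouped_softmax(logits_ab, src)  # normalized over the edges of each side-a vertex
    p_ba = grouped_softmax(logits_ba, dst)  # normalized over the edges of each side-b vertex
    return [math.sqrt(x * y) for x, y in zip(p_ab, p_ba)]
```

With three edges (0, 0), (0, 1), (1, 1) and all-zero logits, the shared edge (0, 1) gets sqrt(0.5 * 0.5) = 0.5, while each of the other two edges gets sqrt(0.5), about 0.707, because it is the only candidate on one of its sides.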

Example of Usage (1) Use WeaveNet to solve stable matching or any other combinatorial optimization problem:

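from weavenet import WeaveNet
from weavenet.sparse.model import TrainableMatchingModuleSp, WeaveNetSp

weave_net_sp = TrainableMatchingModuleSp(
    net_sparse=WeaveNetSp(2*32, [64]*3, [32]*3),
    net_dense=WeaveNet(2, [64]*3, [32]*3),
)

# batches and calc_loss are placeholders for your data loader and loss function.
for xab, xba_t, y_true in batches:
    y_pred = weave_net_sp(xab, xba_t)
    loss = calc_loss(y_pred, y_true)
    ...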

Example of Usage (2) Use WeaveNet for matching extracted features:

from weavenet import WeaveNet
from weavenet.sparse.model import TrainableMatchingModuleSp, WeaveNetSp
from layers import CrossConcatVertexFeatures

# vfeature_channels, compute_similarity_cosine, softmax, feature_extractor,
# batches, and calc_loss are placeholders to be supplied by the user.
weave_net_sp = TrainableMatchingModuleSp(
    net_sparse=WeaveNetSp(2*32, [64]*3, [32]*3),
    net_dense=WeaveNet(2*vfeature_channels+2, [64]*3, [32]*3),
    pre_interactor_dense=CrossConcatVertexFeatures(compute_similarity_cosine, softmax),
)

for xa, xb, y_true in batches:
    xa = feature_extractor(xa)
    xb = feature_extractor(xb)
    y_pred = weave_net_sp(xa, xb)
    loss = calc_loss(y_pred, y_true)
    ...

forward(xab, xba_t)

Tries to match bipartite agents on sides a and b.

Shape:
  • xab: \((\ldots, N, M, C)\)

  • xba_t: \((\ldots, N, M, C)\)

  • output: \((\ldots, N, M, \text{output_channels}')\) if dim_a = -3 and dim_b = -2. \(C' = 2C\) if compute_similarity is None, \(C' = 2C+1\) if only compute_similarity is set, and \(C' = 2C+2\) if both compute_similarity and directional_normalization are set.
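The C' rule above can be checked with a small sketch of cross-concatenation for vertex features. It assumes that the two vertex features are concatenated for every (i, j) pair (2C channels), that compute_similarity appends one similarity channel, and that directional_normalization appends one more channel holding the similarity normalized over side b for each side-a vertex. cross_concat_vertices, cosine, and softmax here are hypothetical pure-Python helpers, not the library's CrossConcat classes.

```python
import math
from typing import Callable, List, Optional, Sequence

Vec = List[float]


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def softmax(row: Sequence[float]) -> Vec:
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_concat_vertices(
    xa: List[Vec],
    xb: List[Vec],
    compute_similarity: Optional[Callable[[Vec, Vec], float]] = None,
    directional_normalization: Optional[Callable[[Vec], Vec]] = None,
) -> List[List[Vec]]:
    """Hypothetical sketch: N x C and M x C vertex features -> N x M x C' edge features."""
    # Concatenate the two vertex features for every (i, j) pair: C' = 2C.
    out = [[list(a) + list(b) for b in xb] for a in xa]
    if compute_similarity is None:
        return out  # directional_normalization is ignored without a similarity
    sim = [[compute_similarity(a, b) for b in xb] for a in xa]
    for i, row in enumerate(out):
        for j, feat in enumerate(row):
            feat.append(sim[i][j])  # C' = 2C + 1
    if directional_normalization is not None:
        for i, row in enumerate(out):
            normalized = directional_normalization(sim[i])  # over side b for vertex i
            for j, feat in enumerate(row):
                feat.append(normalized[j])  # C' = 2C + 2
    return out
```

Under these assumptions, the 2C+2 case is what gives the 2*vfeature_channels+2 input width of the dense net in Example of Usage (2).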

Parameters
  • xab (Tensor) – vertex features on the side a or edge features directed from the side a to b.

  • xba_t (Tensor) – vertex features on the side b or edge features directed from the side b to a.

Returns

A tuple of tensors aggregated by stream_aggregator after being processed through the network.

Return type

Tuple[Tensor, Tensor, Tensor]

Experimental

class weavenet.sparse.model.ExperimentalUnitListGeneratorSp(input_channels, output_channels_list)

Sparse version of ExperimentalUnitListGenerator.

Parameters
  • input_channels (int) – input_channels for the first unit.

  • mid_channels_list – mid_channels for each of the point-net-based set encoders.

  • output_channels_list (List[int]) – output_channels for the units.

class Encoder(in_channels, mid_channels, output_channels, **kwargs)
Parameters
  • in_channels (int) –

  • mid_channels (int) –

  • output_channels (int) –

Private Models

class weavenet.sparse.model.UnitSp(encoder, order, normalizer=None, activator=None)

A sparse version of Unit.

Parameters
  • encoder (Module) – a trainable unit

  • order (typing_extensions.Literal[ena, nae, ean, ane]) – the processing order, one of ‘ena’, ‘nae’, ‘ean’, or ‘ane’, where e: encoder, n: normalizer, a: activator. For example, ‘ena’ applies encoder -> normalizer -> activator.

  • normalizer (Optional[Module]) – a normalizer, such as nn.BatchNorm1d.

  • activator (Optional[Module]) – an activation function, such as nn.PReLU.

forward(x, vertex_id)

Applies the unit process. In __init__(), this method is replaced by one of the Unit._forward_* functions, chosen according to the argument order.
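The replacement of forward in __init__() can be sketched in plain Python. UnitSketch is a made-up stand-in for UnitSp, assuming that only the encoder receives vertex_id and that a missing normalizer or activator is simply skipped; the four _forward_* methods spell out the orders listed for the order parameter.

```python
from typing import Any, Callable, Optional


class UnitSketch:
    """Hypothetical stand-in for UnitSp, showing how forward is bound in __init__."""

    ORDERS = ("ena", "nae", "ean", "ane")

    def __init__(
        self,
        encoder: Callable[[Any, Any], Any],
        order: str,
        normalizer: Optional[Callable[[Any], Any]] = None,
        activator: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        if order not in self.ORDERS:
            raise ValueError(f"order must be one of {self.ORDERS}, got {order!r}")
        self.encoder = encoder
        self.normalizer = normalizer
        self.activator = activator
        # Bind forward once to the method that matches the requested order.
        self.forward = getattr(self, "_forward_" + order)

    def _norm(self, x):
        return x if self.normalizer is None else self.normalizer(x)

    def _act(self, x):
        return x if self.activator is None else self.activator(x)

    def _forward_ena(self, x, vertex_id):  # encoder -> normalizer -> activator
        return self._act(self._norm(self.encoder(x, vertex_id)))

    def _forward_nae(self, x, vertex_id):  # normalizer -> activator -> encoder
        return self.encoder(self._act(self._norm(x)), vertex_id)

    def _forward_ean(self, x, vertex_id):  # encoder -> activator -> normalizer
        return self._norm(self._act(self.encoder(x, vertex_id)))

    def _forward_ane(self, x, vertex_id):  # activator -> normalizer -> encoder
        return self.encoder(self._norm(self._act(x)), vertex_id)
```

Binding the method once in __init__() means the order string is checked a single time rather than on every call.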

Shape:
  • x: \((\text{num_edges_in_batch}, C)\)

  • vertex_id: \((\text{num_edges_in_batch})\)

Parameters
  • x (Tensor) – edge features, flattened across the batch.

  • vertex_id (Tensor) – a list of vertex IDs, one per edge; each ID identifies the source or target vertex of that edge.

Returns

The processed features.

Return type

Tensor
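Because x holds the edges of the whole batch in one flat array, vertex_id is what tells which edges belong to the same vertex. The sketch below illustrates one way such IDs can be used: a point-net-style set pooling that takes the channel-wise maximum over all edges sharing an ID and broadcasts it back to every edge. It is an assumption for illustration (max_pool_by_vertex is a hypothetical helper); the pooling actually used by the set encoders is not specified in this document.

```python
from typing import Dict, List, Sequence


def max_pool_by_vertex(
    x: Sequence[Sequence[float]], vertex_id: Sequence[int]
) -> List[List[float]]:
    """Hypothetical helper: channel-wise max over edges sharing a vertex id.

    x has shape (num_edges_in_batch, C); vertex_id has shape (num_edges_in_batch,).
    """
    pooled: Dict[int, List[float]] = {}
    for feat, vid in zip(x, vertex_id):
        if vid in pooled:
            pooled[vid] = [max(p, f) for p, f in zip(pooled[vid], feat)]
        else:
            pooled[vid] = list(feat)
    # Broadcast each vertex's pooled feature back to its edges, keeping the edge layout.
    return [list(pooled[vid]) for vid in vertex_id]
```

Since the edges of different samples are flattened together, the IDs have to be unique across the whole batch, not only within one sample; otherwise vertices from different samples would be pooled together.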

class weavenet.sparse.model.WeaveNetUnitListGeneratorSp(input_channels, mid_channels_list, output_channels_list)

Sparse version of WeaveNetUnitListGenerator.

Parameters
  • input_channels (int) – input_channels for the first unit.

  • mid_channels_list (List[int]) – mid_channels for each of the point-net-based set encoders.

  • output_channels_list (List[int]) – output_channels for the units.
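The three arguments describe one unit per list entry. The following sketch derives per-unit channel settings that mirror the Encoder(in_channels, mid_channels, output_channels) signature, assuming the units are simply chained (each unit consumes the previous unit's output width) and that both lists have the same length. plan_units and UnitChannels are hypothetical; the real generator may wire channels differently, for example by concatenating the two streams.

```python
from typing import List, NamedTuple, Sequence


class UnitChannels(NamedTuple):
    in_channels: int
    mid_channels: int
    output_channels: int


def plan_units(
    input_channels: int,
    mid_channels_list: Sequence[int],
    output_channels_list: Sequence[int],
) -> List[UnitChannels]:
    """Hypothetical helper: per-unit channel settings, assuming plain chaining."""
    if len(mid_channels_list) != len(output_channels_list):
        raise ValueError("mid_channels_list and output_channels_list must have the same length")
    plan = []
    in_channels = input_channels
    for mid, out in zip(mid_channels_list, output_channels_list):
        plan.append(UnitChannels(in_channels, mid, out))
        in_channels = out  # assumption: the next unit consumes this unit's output
    return plan
```

For arguments shaped like those passed to WeaveNetSp in the examples, plan_units(2*32, [64]*3, [32]*3) yields three units with input widths 64, 32, and 32, each with 64 mid channels and 32 output channels.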