Models

class weavenet.model.TrainableMatchingModule(net, output_channels=1, pre_interactor=CrossConcat(), stream_aggregator=DualSoftmaxSqrt((sm_col): Softmax(dim=-2), (sm_row): Softmax(dim=-3)))

Wraps a GNN head to solve various matching problems.

Parameters
  • net (Module) – a GNN head that estimates the matching.

  • output_channels (int) – the number of matching results (Default: 1).

  • pre_interactor (Optional[CrossConcat]) – the interactor that first merges the two inputs in the forward function (Default: CrossConcat).

  • stream_aggregator (Optional[Callable[[Tensor, Optional[Tensor], bool], Tuple[Tensor, Tensor, Tensor]]]) – the aggregator that merges the estimates from all streams (Default: DualSoftmaxSqrt).
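To make the role of stream_aggregator more concrete, here is a minimal pure-Python sketch of what a "dual softmax, then square root" aggregation could look like on a single N x M score matrix per stream. It is an assumption based on the class name and on the default shown in the signature (one softmax over dim=-2, one over dim=-3), not the library's implementation; the function names softmax and dual_softmax_sqrt are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def dual_softmax_sqrt(xab, xba_t):
    """Merge two N x M score matrices into one N x M matching estimate.

    Assumption: the a->b stream is normalized over the M candidates
    (dim=-2 of an (N, M, C) tensor), the b->a stream over the N
    candidates (dim=-3), and both are merged by an element-wise
    geometric mean.
    """
    n, m = len(xab), len(xab[0])
    # Each row of the a->b stream sums to 1 over the M candidates.
    zab = [softmax(row) for row in xab]
    # Each column of the b->a stream sums to 1 over the N candidates.
    cols = [softmax([xba_t[i][j] for i in range(n)]) for j in range(m)]
    zba_t = [[cols[j][i] for j in range(m)] for i in range(n)]
    # Geometric mean: large only where both sides agree on the pair.
    merged = [
        [math.sqrt(zab[i][j] * zba_t[i][j]) for j in range(m)]
        for i in range(n)
    ]
    return merged, zab, zba_t
```

The sketch returns three values (the merged estimate and the two per-stream estimates), mirroring the three-tensor tuple in the stream_aggregator type annotation.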

Example of Usage (1): Use WeaveNet to solve stable matching or any other combinatorial optimization:

from weavenet import TrainableMatchingModule, WeaveNet

weave_net = TrainableMatchingModule(
    net = WeaveNet(2, [64]*6, [32]*6),
)

for xab, xba_t, y_true in batches:
    y_pred = weave_net(xab, xba_t)
    loss = calc_loss(y_pred, y_true)
    ...

Example of Usage (2): Use WeaveNet for matching extracted features:

from weavenet import TrainableMatchingModule, WeaveNet
from weavenet.layers import CrossConcatVertexFeatures

weave_net = TrainableMatchingModule(
    net = WeaveNet(2*vfeature_channels+2, [64]*6, [32]*6),
    pre_interactor = CrossConcatVertexFeatures(compute_similarity_cosine, softmax)
)

for xa, xb, y_true in batches:
    xa = feature_extractor(xa)
    xb = feature_extractor(xb)
    y_pred = weave_net(xa, xb)
    loss = calc_loss(y_pred, y_true)
    ...

forward(xab, xba_t)

Tries to match the agents of a bipartite graph on sides a and b.

Shape:
  • xab: \((\ldots, N, M, C)\)

  • xba_t: \((\ldots, N, M, C)\)

  • output: \((\ldots, N, M, \text{output_channels}')\) if dim_a = -3 and dim_b = -2. \(C' = 2*C\) if compute_similarity is None, \(C' = 2*C+1\) if only compute_similarity is set, and \(C' = 2*C+2\) if both compute_similarity and directional_normalization are set.
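The three channel rules above can be written as a small helper, which is handy when choosing the input channel count of the wrapped network (Example of Usage (2) passes 2*vfeature_channels+2, which matches the case where both options are set). This is a sketch only; the function name and its boolean flags are hypothetical and not part of the library.

```python
def cross_concat_channels(c, has_similarity=False, has_directional_norm=False):
    """Return C' for C input channels according to the documented rules.

    C' = 2*C      if compute_similarity is None
    C' = 2*C + 1  if only compute_similarity is set
    C' = 2*C + 2  if compute_similarity and directional_normalization are set
    """
    if has_directional_norm and not has_similarity:
        # The documentation does not describe this combination,
        # so the sketch refuses to guess.
        raise ValueError("case not covered by the documented shape rules")
    return 2 * c + int(has_similarity) + int(has_directional_norm)
```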

Parameters
  • xab (Tensor) – vertex features on the side a or edge features directed from the side a to b.

  • xba_t (Tensor) – vertex features on the side b or edge features directed from the side b to a.
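As an illustration of the documented shapes, the sketch below builds xab and xba_t from two plain score matrices using nested lists instead of tensors. It assumes that the "_t" suffix means the b-to-a features are transposed so that both inputs are indexed as (a_i, b_j) with shape (N, M, C); the helper name to_edge_features is hypothetical.

```python
def to_edge_features(scores_ab, scores_ba):
    """Build (N, M, 1) inputs from an N x M and an M x N score matrix.

    scores_ab[i][j]: how much a_i prefers b_j.
    scores_ba[j][i]: how much b_j prefers a_i.
    """
    n, m = len(scores_ab), len(scores_ab[0])
    # One feature channel (C = 1) per directed edge a_i -> b_j.
    xab = [[[scores_ab[i][j]] for j in range(m)] for i in range(n)]
    # Transpose the b -> a scores so entry [i][j] also refers to (a_i, b_j).
    xba_t = [[[scores_ba[j][i]] for j in range(m)] for i in range(n)]
    return xab, xba_t
```

With C = 1 per stream, concatenating the two streams gives 2 channels, which is consistent with the first argument of WeaveNet(2, ...) in Example of Usage (1).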

Returns

A tuple of tensors aggregated by stream_aggregator after the inputs have been processed by the network.

Return type

Tuple[Tensor, Tensor, Tensor]

Experimental

class weavenet.model.ExperimentalUnitListGenerator(input_channels, mid_channels_list, output_channels_list)

A factory of experimental units. This is a sample class showing how to define custom units.

Parameters
  • input_channels (int) – input_channels for the first unit.

  • mid_channels_list (List[int]) – mid_channels for each point-net-based set encoder.

  • output_channels_list (List[int]) – output_channels for the units.

class Encoder(in_channels, mid_channels, output_channels, **kwargs)

A sample encoder for an experimental unit.

Parameters
  • in_channels (int) – input channels for the unit.

  • mid_channels (int) – mid_channels for the point-net-based set encoder.

  • output_channels (int) – output_channels for the unit.

Private Models

class weavenet.model.Unit(encoder, order, normalizer=None, activator=None)

Applies the encoder, normalizer, and activator in the specified order.

Parameters
  • encoder (Module) – a trainable unit

  • order (typing_extensions.Literal[ena, nae, ean, ane]) – the processing order, one of 'ena', 'nae', 'ean', or 'ane', where e is the encoder, a the activator, and n the normalizer. For example, 'ena' applies encoder -> normalizer -> activator.

  • normalizer (Optional[Module]) – a normalizer, such as BatchNormXXC.

  • activator (Optional[Module]) – an activation function, such as nn.PReLU.
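The effect of the order argument can be sketched with plain callables standing in for the modules. This is an illustrative sketch, not the library's code: the function name apply_unit is hypothetical, and skipping a stage that is None is an assumption based on normalizer and activator being optional.

```python
def apply_unit(x, order, encoder, normalizer=None, activator=None):
    """Apply encoder (e), normalizer (n), and activator (a) in the given order."""
    if order not in ("ena", "nae", "ean", "ane"):
        raise ValueError(f"unsupported order: {order!r}")
    stages = {"e": encoder, "n": normalizer, "a": activator}
    for key in order:
        stage = stages[key]
        if stage is not None:  # assumption: optional stages are skipped
            x = stage(x)
    return x


# Toy stand-ins for the real modules, acting on a single float.
double = lambda v: v * 2.0       # "encoder"
shift = lambda v: v - 1.0        # "normalizer"
relu = lambda v: max(v, 0.0)     # "activator"
```

Because the activator is non-linear, the order matters: with these stand-ins, the same input produces different outputs for 'ena', 'ean', and 'ane'.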

forward(x, dim_target)

Applies the unit's processing. In __init__(), this function is replaced by one of the Unit._forward_* functions, selected according to the order argument.

Shape:
  • x: \((\ldots, C)\)

Parameters
  • x (Tensor) – the source features.

  • dim_target (int) – the dimension of the target vertices.

Returns

The processed features.

Return type

Tensor

class weavenet.model.UnitListGenerator(input_channels, output_channels_list)

A factory of units.

Parameters
  • input_channels (int) – input_channels for the first unit.

  • output_channels_list (List[int]) – output_channels for the units.

generate(interactor=None)

Generates a list of units, assuming that the interactor is applied at the end of each unit's processing.

Parameters

interactor (Optional[Interactor]) – a concrete Interactor instance, typically CrossConcat. If None, no interaction is assumed at the end of each unit's processing.

Returns

a list of units.

Return type

List[Unit]
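One reason generate() needs to know about the interactor is that an interaction step can change the channel count seen by the next unit. The sketch below plans (input, output) channel pairs for a chain of units under the assumption that a CrossConcat-style interactor doubles the channels by concatenating two streams, and that no interactor leaves them unchanged. The function name plan_unit_channels and the interactor_multiplier parameter are hypothetical.

```python
def plan_unit_channels(input_channels, output_channels_list, interactor_multiplier=1):
    """Return the (in_channels, out_channels) pair for each unit in the chain.

    interactor_multiplier: factor by which the interactor changes the
    channel count after each unit (assumed 2 for a cross-concatenation,
    1 when no interactor is used).
    """
    plan = []
    in_ch = input_channels
    for out_ch in output_channels_list:
        plan.append((in_ch, out_ch))
        # The next unit consumes what the interactor emits.
        in_ch = out_ch * interactor_multiplier
    return plan
```

For instance, under these assumptions plan_unit_channels(2, [32, 32, 32], interactor_multiplier=2) gives an input width of 64 for every unit after the first.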

class weavenet.model.WeaveNetUnitListGenerator(input_channels, mid_channels_list, output_channels_list)

A factory of WeaveNet units.

Parameters
  • input_channels (int) – input_channels for the first unit.

  • mid_channels_list (List[int]) – mid_channels for each point-net-based set encoder.

  • output_channels_list (List[int]) – output_channels for the units.
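To show where a separate mid_channels value can enter a point-net-based set encoder, here is a generic PointNet-style sketch in pure Python: embed each element, max-pool over the set, then combine each element with the pooled summary. The exact layer layout used by the library is not stated in this reference, so the structure below, the weight shapes, and the names linear and set_encode are all assumptions for illustration.

```python
def linear(vec, weight):
    """Multiply a feature vector by a weight matrix (one row per output channel)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def set_encode(elements, w_mid, w_out):
    """Encode a set of feature vectors in a permutation-equivariant way.

    w_mid: mid_channels x in_channels
    w_out: out_channels x (in_channels + mid_channels)
    """
    # 1. Embed every element independently into mid_channels features.
    mid = [linear(x, w_mid) for x in elements]
    # 2. Max-pool over the set; max is symmetric, so element order is irrelevant.
    pooled = [max(col) for col in zip(*mid)]
    # 3. Concatenate each element with the pooled summary and project.
    return [linear(x + pooled, w_out) for x in elements]
```

Reordering the input elements only reorders the outputs in the same way, which is the property that makes this kind of encoder suitable for sets of candidates.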