Models
- class weavenet.model.TrainableMatchingModule(net, output_channels=1, pre_interactor=CrossConcat(), stream_aggregator=DualSoftmaxSqrt((sm_col): Softmax(dim=-2), (sm_row): Softmax(dim=-3)))
Wraps a GNN head to solve various matching problems.
- Parameters
net (Module) – a GNN head that estimates matching.
output_channels (int) – the number of matching results (Default: 1).
pre_interactor (Optional[CrossConcat]) – the interactor that first merges the two inputs in the forward function (Default: CrossConcat).
stream_aggregator (Optional[Callable[[Tensor, Optional[Tensor], bool], Tuple[Tensor, Tensor, Tensor]]]) – the aggregator that merges the estimations from all the streams (Default: DualSoftmaxSqrt, which holds a softmax over dim=-2 and a softmax over dim=-3).
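The default stream_aggregator is shown in the signature as a module holding two softmaxes, one over dim=-2 (sm_col) and one over dim=-3 (sm_row). The sketch below illustrates one plausible reading of the name DualSoftmaxSqrt: normalize the same score map along both matching axes and take the element-wise square root of their product (a geometric mean). The combination rule and the function names are assumptions inferred from the class name, not taken from the weavenet source; NumPy stands in for PyTorch, and the real aggregator returns a tuple of three tensors rather than a single map.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable softmax along one axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dual_softmax_sqrt(scores: np.ndarray) -> np.ndarray:
    """Hypothetical sketch; scores has shape (..., N, M, C).

    Assumption: softmax over the N axis (dim=-3, like sm_row) and over the
    M axis (dim=-2, like sm_col), combined by an element-wise geometric mean.
    """
    p_row = softmax(scores, axis=-3)  # sums to 1 over the N agents on side a
    p_col = softmax(scores, axis=-2)  # sums to 1 over the M agents on side b
    return np.sqrt(p_row * p_col)


rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 4, 1))  # N=3 agents on side a, M=4 on side b, C=1
match = dual_softmax_sqrt(scores)
print(match.shape)  # (3, 4, 1)
```

Because both factors are probabilities, every entry of the combined map lies in [0, 1], and an entry is large only when the pair scores highly from both sides' points of view.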
Example of Usage (1) Use WeaveNet to solve stable matching or any other combinatorial optimization:
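    from weavenet import TrainableMatchingModule, WeaveNet

    # The GNN head is passed as the first positional argument (`net`).
    weave_net = TrainableMatchingModule(
        WeaveNet(2, [64]*6, [32]*6),
    )
    # `batches`, `y_true`, and `calc_loss` are user-provided placeholders.
    for xab, xba_t, y_true in batches:
        y_pred = weave_net(xab, xba_t)  # tuple of tensors from stream_aggregator
        loss = calc_loss(y_pred, y_true)
        ...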
Example of Usage (2) Use WeaveNet for matching extracted features:

    from weavenet import TrainableMatchingModule, WeaveNet
    from layers import CrossConcatVertexFeatures

    # vfeature_channels, compute_similarity_cosine, softmax, feature_extractor,
    # batches, and calc_loss are user-provided placeholders.
    weave_net = TrainableMatchingModule(
        # 2*C+2 input channels: compute_similarity and directional
        # normalization are both set (see Shape under forward).
        WeaveNet(2*vfeature_channels+2, [64]*6, [32]*6),
        pre_interactor=CrossConcatVertexFeatures(compute_similarity_cosine, softmax),
    )
    for xa, xb, y_true in batches:
        xa = feature_extractor(xa)
        xb = feature_extractor(xb)
        y_pred = weave_net(xa, xb)
        loss = calc_loss(y_pred, y_true)
        ...
- forward(xab, xba_t)
Tries to match bipartite agents on sides a and b.
- Shape:
xab: \((\ldots, N, M, C)\)
xba_t: \((\ldots, N, M, C)\)
output: \((\ldots, N, M, \text{output\_channels})\) if dim_a = -3 and dim_b = -2. The channel count \(C'\) of the features passed to the head is \(C' = 2C\) if compute_similarity is None, \(C' = 2C+1\) if only compute_similarity is set, and \(C' = 2C+2\) if both compute_similarity and directional_normalization are set.
- Parameters
xab (Tensor) – vertex features on side a, or edge features directed from side a to side b.
xba_t (Tensor) – vertex features on side b, or edge features directed from side b to side a.
- Returns
The result of passing the inputs through the network and aggregating the streams with stream_aggregator.
- Return type
Tuple[Tensor, Tensor, Tensor]
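For vertex features, as in Example of Usage (2), the pre-interactor has to turn per-vertex features on each side into per-edge features of shape (N, M, C'). The hypothetical NumPy sketch below reproduces the channel rule stated under Shape: it broadcasts xa of shape (N, C) and xb of shape (M, C) to (N, M, C) each and concatenates them (2C), optionally appends a cosine-similarity channel (2C+1), and, as an assumption about what directional normalization means, replaces that channel with two softmax-normalized copies, one per direction (2C+2). The function name and the exact normalization are illustrative, not the CrossConcatVertexFeatures implementation.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cross_concat_vertex_features(xa, xb, compute_similarity=False,
                                 directional_normalization=False):
    """Hypothetical sketch: xa (N, C), xb (M, C) -> edge features (N, M, C')."""
    n, c = xa.shape
    m, _ = xb.shape
    ea = np.broadcast_to(xa[:, None, :], (n, m, c))  # vertex a repeated along M
    eb = np.broadcast_to(xb[None, :, :], (n, m, c))  # vertex b repeated along N
    feats = [ea, eb]                                  # 2C channels
    if compute_similarity:
        an = xa / np.linalg.norm(xa, axis=-1, keepdims=True)
        bn = xb / np.linalg.norm(xb, axis=-1, keepdims=True)
        sim = an @ bn.T                               # (N, M) cosine similarity
        if directional_normalization:
            # Assumption: one softmax per direction gives two channels.
            feats += [softmax(sim, axis=1)[..., None], softmax(sim, axis=0)[..., None]]
        else:
            feats.append(sim[..., None])
    return np.concatenate(feats, axis=-1)


rng = np.random.default_rng(0)
xa, xb = rng.normal(size=(3, 8)), rng.normal(size=(5, 8))  # C = 8
for flags in [(False, False), (True, False), (True, True)]:
    print(flags, cross_concat_vertex_features(xa, xb, *flags).shape)
# (False, False) (3, 5, 16)
# (True, False) (3, 5, 17)
# (True, True) (3, 5, 18)
```

This is why the head in Example of Usage (2) is built with 2*vfeature_channels+2 input channels.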
Experimental
- class weavenet.model.ExperimentalUnitListGenerator(input_channels, mid_channels_list, output_channels_list)
A factory of experimental units. This is a sample class showing how users can define custom units.
- Parameters
input_channels (int) – the number of input channels of the first unit.
mid_channels_list (List[int]) – mid_channels for each point-net-based set encoder.
output_channels_list (List[int]) – the number of output channels of each unit.
- class Encoder(in_channels, mid_channels, output_channels, **kwargs)
A sample encoder for an experimental unit.
- Parameters
in_channels (int) – the number of input channels of the unit.
mid_channels (int) – mid_channels for the point-net-based set encoder.
output_channels (int) – the number of output channels of the unit.
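The encoder is described as a point-net-based set encoder. A common PointNet-style pattern, sketched here with NumPy and random weights standing in for trainable layers, applies a shared per-element map to mid_channels, max-pools over the set axis, concatenates the pooled summary back onto every element, and projects to output_channels. The class name is hypothetical, and pooling over the axis given by dim_target is an assumption; the actual layers of Encoder are not documented here.

```python
import numpy as np


class SetEncoderSketch:
    """Hypothetical PointNet-style set encoder: (..., C_in) -> (..., C_out)."""

    def __init__(self, in_channels, mid_channels, output_channels, seed=0):
        rng = np.random.default_rng(seed)
        # Random weights stand in for trainable parameters.
        self.w1 = rng.normal(size=(in_channels, mid_channels)) / np.sqrt(in_channels)
        self.w2 = rng.normal(size=(2 * mid_channels, output_channels)) / np.sqrt(2 * mid_channels)

    def __call__(self, x, dim_target):
        h = np.maximum(x @ self.w1, 0.0)                # shared per-element map + ReLU
        pooled = h.max(axis=dim_target, keepdims=True)  # order-independent pooling over the set
        pooled = np.broadcast_to(pooled, h.shape)       # give every element the set summary
        return np.concatenate([h, pooled], axis=-1) @ self.w2


x = np.random.default_rng(1).normal(size=(3, 5, 4))  # (N, M, C_in)
enc = SetEncoderSketch(4, 16, 8)
y = enc(x, dim_target=-2)  # pool over the M axis
print(y.shape)  # (3, 5, 8)
```

Max pooling makes the encoder permutation-equivariant along the pooled axis: reordering the M elements reorders the outputs the same way, which is the property a set encoder needs.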
Private Models
- class weavenet.model.Unit(encoder, order, normalizer=None, activator=None)
Applies an encoder, a normalizer, and an activator in sequence, in the order given by order.
- Parameters
encoder (Module) – a trainable unit
order (typing_extensions.Literal['ena', 'nae', 'ean', 'ane']) – the processing order, where e: encoder, n: normalizer, a: activator. For example, 'ena' applies encoder -> normalizer -> activator.
normalizer (Optional[Module]) – a normalizer, such as BatchNormXXC.
activator (Optional[Module]) – an activation function, such as nn.PReLU.
- forward(x, dim_target)
Applies the unit process. In __init__(), this method is replaced by one of the Unit._forward_* functions, selected by the order argument.
- Shape:
x: \((\ldots, C)\)
- Parameters
x (Tensor) – the source features.
dim_target (int) – the dimension (axis) of the target vertices.
- Returns
The processed features.
- Return type
Tensor
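The order argument and the replacement of forward in __init__() can be illustrated with a small, framework-free sketch. UnitSketch below is hypothetical (it is not weavenet.model.Unit); it binds forward once, at construction time, to one of four _forward_* methods chosen from order. Treating a missing normalizer or activator as the identity, and passing dim_target only to the encoder, are assumptions.

```python
class UnitSketch:
    """Hypothetical stand-in for weavenet.model.Unit."""

    def __init__(self, encoder, order, normalizer=None, activator=None):
        if order not in ("ena", "nae", "ean", "ane"):
            raise ValueError(f"unknown order: {order!r}")
        self.encoder = encoder
        # Missing optional steps default to the identity.
        self.normalizer = normalizer or (lambda x: x)
        self.activator = activator or (lambda x: x)
        # Choose the processing order once instead of branching on every call.
        self.forward = getattr(self, f"_forward_{order}")

    def _forward_ena(self, x, dim_target):  # encoder -> normalizer -> activator
        return self.activator(self.normalizer(self.encoder(x, dim_target)))

    def _forward_nae(self, x, dim_target):  # normalizer -> activator -> encoder
        return self.encoder(self.activator(self.normalizer(x)), dim_target)

    def _forward_ean(self, x, dim_target):  # encoder -> activator -> normalizer
        return self.normalizer(self.activator(self.encoder(x, dim_target)))

    def _forward_ane(self, x, dim_target):  # activator -> normalizer -> encoder
        return self.encoder(self.normalizer(self.activator(x)), dim_target)


trace = []
unit = UnitSketch(
    encoder=lambda x, d: trace.append("e") or x + 1,
    order="nae",
    normalizer=lambda x: trace.append("n") or x * 2,
    activator=lambda x: trace.append("a") or max(x, 0),
)
print(unit.forward(3, dim_target=-2), trace)  # 7 ['n', 'a', 'e']
```

Binding the method in __init__ keeps the per-call path free of string comparisons on order.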
- class weavenet.model.UnitListGenerator(input_channels, output_channels_list)
A factory of units.
- Parameters
input_channels (int) – the number of input channels of the first unit.
output_channels_list (List[int]) – the number of output channels of each unit.
- generate(interactor=None)
Generates a list of units, assuming that the interactor is applied at the end of each unit process.
- Parameters
interactor (Optional[Interactor]) – a concrete subclass of Interactor, typically CrossConcat. If None, no interaction is applied at the end of each unit process.
- Returns
a list of units.
- Return type
List[Unit]
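Because the factory receives only input_channels and a list of per-unit output channels, each unit's input width has to be derived from the unit before it. The hypothetical helper below sketches that bookkeeping under the assumption that an interactor such as CrossConcat concatenates two streams after every unit and so doubles the width seen by the next unit (consistent with \(C' = 2C\) for cross-concatenation), while interactor=None passes the width through unchanged. The helper name and the concat_factor parameter are illustrative and are not part of weavenet.

```python
from typing import List, Tuple


def plan_unit_channels(input_channels: int,
                       output_channels_list: List[int],
                       concat_factor: int = 1) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_channels, out_channels) for each unit.

    concat_factor=2 models an interactor that concatenates two streams
    after every unit; concat_factor=1 models interactor=None.
    """
    plan = []
    in_ch = input_channels
    for out_ch in output_channels_list:
        plan.append((in_ch, out_ch))
        in_ch = out_ch * concat_factor  # width seen by the next unit
    return plan


print(plan_unit_channels(2, [32] * 3, concat_factor=2))
# [(2, 32), (64, 32), (64, 32)]
```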
- class weavenet.model.WeaveNetUnitListGenerator(input_channels, mid_channels_list, output_channels_list)
A factory of WeaveNet units.
- Parameters
input_channels (int) – the number of input channels of the first unit.
mid_channels_list (List[int]) – mid_channels for each point-net-based set encoder.
output_channels_list (List[int]) – the number of output channels of each unit.