Sparse Models
- class weavenet.sparse.model.TrainableMatchingModuleSp(net_sparse, pre_interactor_sparse=CrossConcat(), mask_selector=MaskSelectorByLinearInferenceOr(), output_channels=1, net_dense=None, pre_interactor_dense=CrossConcat(), stream_aggregator=DualSoftmaxSqrtSp())
A variant of TrainableMatchingModule that handles sparse bipartite graphs.
- Parameters
net – a sparse GNN net that estimates the matching.
mask_estimator – an algorithm that estimates a mask from the input of forward() or from the output of pre_net.
output_channels (int) – see TrainableMatchingModule.
pre_interactor – see TrainableMatchingModule.
pre_net – a GNN pre_net applied before mask_estimator.
stream_aggregator (Optional[Callable[[Tensor, Tensor, Tensor, Optional[Tensor]], Tuple[Tensor, Tensor, Tensor]]]) – the aggregator that merges the estimates from all the streams (Default: DualSoftmaxSqrtSp).
net_sparse (Module) –
pre_interactor_sparse (Optional[CrossConcat]) –
mask_selector (Module) –
net_dense (Optional[Module]) –
pre_interactor_dense (Optional[CrossConcat]) –
Example of Usage (1) Use WeaveNet to solve stable matching or any other combinatorial optimization:
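    # Import paths are kept as written in these examples. The class itself is
    # documented above as weavenet.sparse.model.TrainableMatchingModuleSp, so
    # adjust the paths to your installed version if needed.
    from sparse.weavenet import TrainableMatchingModuleSp, WeaveNetSp
    from weavenet import WeaveNet

    # Keyword names follow the constructor signature above.
    weave_net_sp = TrainableMatchingModuleSp(
        net_sparse=WeaveNetSp(2*32, [64]*3, [32]*3),
        net_dense=WeaveNet(2, [64]*3, [32]*3),
    )

    # batches and calc_loss are user-defined.
    for xab, xba_t, y_true in batches:
        y_pred = weave_net_sp(xab, xba_t)
        loss = calc_loss(y_pred, y_true)
        ...

Example of Usage (2) Use WeaveNet for matching extracted features:

    # Import paths are kept as written in the original example.
    from sparse.weavenet import TrainableMatchingModuleSp, WeaveNetSp
    from weavenet import WeaveNet
    from layers import CrossConcatVertexFeatures

    # Keyword names follow the constructor signature above.
    weave_net_sp = TrainableMatchingModuleSp(
        net_sparse=WeaveNetSp(2*32, [64]*3, [32]*3),
        net_dense=WeaveNet(2*vfeature_channels+2, [64]*3, [32]*3),
        pre_interactor_dense=CrossConcatVertexFeatures(compute_similarity_cosine, softmax),
    )

    # batches, feature_extractor and calc_loss are user-defined.
    for xa, xb, y_true in batches:
        xa = feature_extractor(xa)
        xb = feature_extractor(xb)
        y_pred = weave_net_sp(xa, xb)
        loss = calc_loss(y_pred, y_true)
        ...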
- forward(xab, xba_t)
Tries to match the agents on sides a and b of a bipartite graph.
- Shape:
xab: \((\ldots, N, M, C)\)
xba_t: \((\ldots, N, M, C)\)
output: \((\ldots, N, M, \text{output_channels}')\) if dim_a = -3 and dim_b = -2. \(C' = 2*C\) if compute_similarity is None, \(C' = 2*C+1\) if only compute_similarity is set, and \(C' = 2*C+2\) if both compute_similarity and directional_normalization are set.
- Parameters
xab (Tensor) – vertex features on the side a or edge features directed from the side a to b.
xba_t (Tensor) – vertex features on the side b or edge features directed from the side b to a.
- Returns
The resulting tensors, aggregated by stream_aggregator after being processed through the network.
- Return type
Tuple[Tensor, Tensor, Tensor]
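The channel arithmetic in the Shape note can be checked with a few lines of plain Python. This is only a sketch: `cross_concat_channels` is a hypothetical helper, not part of weavenet, and it rejects the one flag combination that the note does not describe.

```python
def cross_concat_channels(c, compute_similarity=False, directional_normalization=False):
    """Return C' for C input channels per side, following the Shape note.

    Hypothetical helper for illustration only; it is not part of weavenet.
    """
    if directional_normalization and not compute_similarity:
        # The documentation does not say what happens in this case.
        raise ValueError(
            "directional_normalization without compute_similarity is not documented"
        )
    extra = int(compute_similarity) + int(directional_normalization)
    return 2 * c + extra


vfeature_channels = 16  # arbitrary example value
print(cross_concat_channels(vfeature_channels))              # 32
print(cross_concat_channels(vfeature_channels, True))        # 33
print(cross_concat_channels(vfeature_channels, True, True))  # 34
```

The last case corresponds to the `2*vfeature_channels+2` input width used in Example of Usage (2).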
Experimental
- class weavenet.sparse.model.ExperimentalUnitListGeneratorSp(input_channels, output_channels_list)
Sparse version of ExperimentalUnitListGenerator.
- Parameters
input_channels (int) – input_channels for the first unit.
mid_channels_list – mid_channels for each point-net-based set encoder.
output_channels_list (List[int]) – output_channels for the units.
Private Models
- class weavenet.sparse.model.UnitSp(encoder, order, normalizer=None, activator=None)
A sparse version of Unit.
- Parameters
encoder (Module) – a trainable unit.
order (typing_extensions.Literal[ena, nae, ean, ane]) – the processing order, one of 'ena', 'nae', 'ean', 'ane' (e: encoder, a: activator, n: normalizer). For example, 'ena' applies encoder -> normalizer -> activator.
normalizer (Optional[Module]) – a normalizer, such as nn.BatchNorm1d.
activator (Optional[Module]) – an activation function, such as nn.PReLU.
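The effect of order can be illustrated without the library. The sketch below composes three stand-in callables in the sequence spelled by the order string. `compose_by_order` and the toy stages are hypothetical: they only mimic the e/n/a convention described above and are not how UnitSp is actually implemented. Skipping a stage that is None is likewise an assumption.

```python
def compose_by_order(order, encoder, normalizer=None, activator=None):
    """Build a function that applies the stages in the sequence spelled by `order`.

    'e' = encoder, 'n' = normalizer, 'a' = activator. Stages that are None
    are skipped (an assumption made for this sketch).
    """
    if order not in ("ena", "nae", "ean", "ane"):
        raise ValueError(f"unsupported order: {order!r}")
    stages = {"e": encoder, "n": normalizer, "a": activator}
    pipeline = [stages[key] for key in order if stages[key] is not None]

    def forward(x):
        for stage in pipeline:
            x = stage(x)
        return x

    return forward


# Toy stages on plain floats so the ordering is easy to follow.
def encode(x):
    return 2.0 * x  # stand-in for a trainable encoder


def normalize(x):
    return x - 1.0  # stand-in for a normalizer


def activate(x):
    return max(x, 0.0)  # ReLU-like stand-in for an activator


print(compose_by_order("ena", encode, normalize, activate)(0.25))  # 0.0
print(compose_by_order("ean", encode, normalize, activate)(0.25))  # -0.5
```

With the same input, 'ena' clips the negative value after normalization, while 'ean' normalizes last and so returns a negative result.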
- forward(x, vertex_id)
Applies the unit process. This function is replaced with one of the Unit._forward_* functions in __init__() based on the argument order.
- Shape:
x: \((\text{num_edges_in_batch}, C)\)
vertex_id: \((\text{num_edges_in_batch})\)
- Parameters
x (Tensor) – edge features, flattened across the batch.
vertex_id (Tensor) – the vertex ID of each edge (the ID identifies the source or target vertex of the edge).
- Returns
The processed features.
- Return type
Tensor
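To make the shapes above concrete, the following sketch flattens a small dense batch into an edge-list layout: one feature row per kept edge, plus one vertex ID per edge. It uses plain Python lists rather than tensors. `flatten_edges`, the mask, and the per-batch offset that keeps vertex IDs unique are all assumptions made for illustration, not the library's actual conversion.

```python
def flatten_edges(features, mask):
    """Flatten dense (B, N, M, C) edge features into an edge list.

    `mask[b][i][j]` is True for edges to keep. Returns (x, src_id, tgt_id),
    where x holds one C-dimensional row per kept edge. Vertex IDs are offset
    per batch element so they stay unique across the batch (an assumption
    made for this sketch).
    """
    x, src_id, tgt_id = [], [], []
    for b, (feat_b, mask_b) in enumerate(zip(features, mask)):
        n, m = len(feat_b), len(feat_b[0])
        for i in range(n):
            for j in range(m):
                if mask_b[i][j]:
                    x.append(feat_b[i][j])
                    src_id.append(b * n + i)
                    tgt_id.append(b * m + j)
    return x, src_id, tgt_id


# One batch element, N=2 agents on side a, M=3 on side b, C=1 channel.
features = [[[[0.9], [0.1], [0.5]],
             [[0.2], [0.8], [0.4]]]]
mask = [[[True, False, True],
         [False, True, False]]]

x, src_id, tgt_id = flatten_edges(features, mask)
print(len(x), len(x[0]))  # 3 1, i.e. (num_edges_in_batch, C)
print(src_id)             # [0, 0, 1]
print(tgt_id)             # [0, 2, 1]
```

Either `src_id` or `tgt_id` could play the role of vertex_id, depending on which side of the edge the unit groups by.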
- class weavenet.sparse.model.WeaveNetUnitListGeneratorSp(input_channels, mid_channels_list, output_channels_list)
Sparse version of WeaveNetUnitListGenerator.
- Parameters
input_channels (int) – input_channels for the first unit.
mid_channels_list (List[int]) – mid_channels for each point-net-based set encoder.
output_channels_list (List[int]) – output_channels for the units.