Metrics
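- weavenet.metric.binarize(m)
Binarizes each matrix in a batch into a one-to-one matching, in which every row and every column holds at most one 1. If N=M, every vertex receives exactly one partner. If N>M, N-M row vertices are left without a partner; if M>N, M-N column vertices are left without a partner.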
- Shape:
m: \((B, N, M)\)
output: \((B, N, M)\)
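The sketch below shows one way to turn a soft score matrix into a one-to-one 0/1 matrix of the shape described above. It is a greedy stand-in, not weavenet's implementation: the helpers `binarize_greedy` and `binarize_batch` are hypothetical, the method the library actually uses (greedy, Hungarian, or otherwise) is not stated here, and plain nested lists replace the `(B, N, M)` Tensor.

```python
from typing import List, Sequence


def binarize_greedy(m: Sequence[Sequence[float]]) -> List[List[int]]:
    """Greedy one-to-one binarization of one N x M score matrix (hypothetical helper)."""
    n = len(m)
    k = len(m[0]) if n else 0
    out = [[0] * k for _ in range(n)]
    # Visit every cell from the highest score to the lowest.
    cells = sorted(((m[i][j], i, j) for i in range(n) for j in range(k)), reverse=True)
    used_rows, used_cols = set(), set()
    for _, i, j in cells:
        # Skip cells whose row or column already has a partner.
        if i in used_rows or j in used_cols:
            continue
        out[i][j] = 1
        used_rows.add(i)
        used_cols.add(j)
    # Exactly min(N, M) pairs are formed; surplus rows or columns stay all-zero.
    return out


def binarize_batch(batch):
    """Apply the greedy binarization to each (N, M) matrix of a (B, N, M) batch."""
    return [binarize_greedy(m) for m in batch]


soft = [[[0.9, 0.1], [0.8, 0.7], [0.2, 0.3]]]  # B=1, N=3, M=2
print(binarize_batch(soft))  # [[[1, 0], [0, 1], [0, 0]]]
```

With N=3 and M=2, one row vertex necessarily ends up without a partner, matching the N>M case above.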
- weavenet.metric.is_one2one(m)
Checks whether each matrix in batch m is one-to-one, i.e., contains no duplicated correspondence (no vertex is matched to more than one partner).
- Shape:
m: \((B, N, M)\)
output: \((B)\)
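A minimal sketch of the check, assuming each matrix is already binary and that any nonzero entry counts as a correspondence. The function name mirrors the documented one but this is a hypothetical plain-Python reimplementation returning a list of bools, not weavenet's code; whether the library treats an all-zero matrix as one-to-one is likewise an assumption here.

```python
from typing import List, Sequence


def is_one2one(batch: Sequence[Sequence[Sequence[float]]]) -> List[bool]:
    """Per-matrix one-to-one check for a (B, N, M) batch (hypothetical sketch)."""
    results = []
    for m in batch:
        n_cols = len(m[0]) if m else 0
        # A row with two or more nonzero entries maps one vertex to several partners.
        rows_ok = all(sum(1 for v in row if v != 0) <= 1 for row in m)
        # The same test for columns, i.e. from the other side of the matching.
        cols_ok = all(sum(1 for row in m if row[j] != 0) <= 1 for j in range(n_cols))
        results.append(rows_ok and cols_ok)
    return results


batch = [
    [[1, 0], [0, 1]],          # one-to-one
    [[1, 1], [0, 0]],          # row 0 has two partners
    [[1, 0], [1, 0], [0, 0]],  # column 0 has two partners (N=3, M=2)
]
print(is_one2one(batch))  # [True, False, False]
```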
- weavenet.metric.is_stable(m, sab, sba_t)
Checks whether each matrix in batch m is a stable matching, i.e., one that admits no blocking pair.
- Shape:
m: \((B, N, M)\)
sab: \((B, N, M)\)
sba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
A boolean vector, True for each matrix that is a stable matching.
- Return type
Tensor
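The sketch below spells out the stability test on a single binary one-to-one matching, using the textbook definition: (a_i, b_j) is a blocking pair if they are not partners and each strictly prefers the other to its current situation. It is a hypothetical reimplementation, not weavenet's code. Assumptions: satisfaction values are "higher is better", `sab[i][j]` is a_i's satisfaction with b_j, `sba[j][i]` is b_j's satisfaction with a_i (an M x N layout chosen to match the documented `(B, M, N)` shape; the exact layout of `sba_t` is not confirmed here), and an unmatched agent prefers any partner to none.

```python
from typing import List, Optional, Sequence

Matrix = Sequence[Sequence[float]]


def _partners(m: Matrix):
    """Partner index for every row and every column (None if unmatched)."""
    n, k = len(m), len(m[0])
    row_p: List[Optional[int]] = [None] * n
    col_p: List[Optional[int]] = [None] * k
    for i in range(n):
        for j in range(k):
            if m[i][j]:
                row_p[i], col_p[j] = j, i
    return row_p, col_p


def is_stable_single(m: Matrix, sab: Matrix, sba: Matrix) -> bool:
    """True if the binary one-to-one matching m has no blocking pair."""
    row_p, col_p = _partners(m)
    for i in range(len(m)):
        for j in range(len(m[0])):
            if m[i][j]:
                continue  # a_i and b_j are already partners
            a_wants = row_p[i] is None or sab[i][j] > sab[i][row_p[i]]
            b_wants = col_p[j] is None or sba[j][i] > sba[j][col_p[j]]
            if a_wants and b_wants:
                return False  # (a_i, b_j) is a blocking pair
    return True


def is_stable(batch_m, batch_sab, batch_sba) -> List[bool]:
    """Hypothetical batch wrapper returning one bool per matrix."""
    return [is_stable_single(*args) for args in zip(batch_m, batch_sab, batch_sba)]


sab = [[2, 1], [2, 1]]  # both a_0 and a_1 prefer b_0
sba = [[2, 1], [2, 1]]  # both b_0 and b_1 prefer a_0
stable = [[1, 0], [0, 1]]
unstable = [[0, 1], [1, 0]]  # a_0 and b_0 would rather be together
print(is_stable([stable, unstable], [sab, sab], [sba, sba]))  # [True, False]
```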
- weavenet.metric.count_blocking_pairs(m, sab, sba_t)
Counts the number of blocking pairs for each matrix in batch m.
- Shape:
m: \((B, N, M)\)
sab: \((B, N, M)\)
sba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
A count vector.
- Return type
Tensor
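Counting can be phrased without looking up partners: compute the satisfaction each agent currently receives, then count the cells where both sides would strictly gain. This mask-style formulation is the kind of computation a tensor implementation might use, but the helpers below are hypothetical and not taken from weavenet. Assumptions: `m` is binary and one-to-one, `sab[i][j]` is a_i's satisfaction with b_j, `sba[j][i]` is b_j's satisfaction with a_i, higher is better, and all values are positive so that an unmatched agent's current satisfaction of 0 is worse than any partner.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def count_blocking_pairs_single(m: Matrix, sab: Matrix, sba: Matrix) -> int:
    """Number of blocking pairs in one binary one-to-one matching (hypothetical sketch)."""
    n, k = len(m), len(m[0])
    # Satisfaction each agent currently gets from its partner (0 if unmatched).
    cur_a = [sum(m[i][j] * sab[i][j] for j in range(k)) for i in range(n)]
    cur_b = [sum(m[i][j] * sba[j][i] for i in range(n)) for j in range(k)]
    # (a_i, b_j) blocks if both strictly prefer each other to their current state.
    # Matched pairs never count, because for them sab[i][j] == cur_a[i].
    return sum(
        1
        for i in range(n)
        for j in range(k)
        if sab[i][j] > cur_a[i] and sba[j][i] > cur_b[j]
    )


def count_blocking_pairs(batch_m, batch_sab, batch_sba) -> List[int]:
    """Hypothetical batch wrapper returning one count per matrix."""
    return [count_blocking_pairs_single(*a) for a in zip(batch_m, batch_sab, batch_sba)]


sab = [[2, 1], [2, 1]]
sba = [[2, 1], [2, 1]]
matchings = [
    [[1, 0], [0, 1]],  # stable: no blocking pair
    [[0, 1], [1, 0]],  # (a_0, b_0) blocks
    [[0, 0], [0, 0]],  # nobody matched: every pair blocks
]
print(count_blocking_pairs(matchings, [sab] * 3, [sba] * 3))  # [0, 1, 4]
```

A matching is stable exactly when its count is zero, which ties this metric to is_stable.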
- weavenet.metric.sexequality_cost(m, cab, cba_t, pformat=PreferenceFormat.cost)
Calculates the sex-equality cost of each matching in batch m.
- Shape:
m: \((B, N, M)\)
cab: \((B, N, M)\)
cba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
A cost vector.
- Return type
Tensor
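In the stable marriage literature, the sex-equality cost of a matching is the absolute difference between the total cost borne by side A and the total cost borne by side B. The sketch below computes that textbook quantity for one matching; the helper names are hypothetical, and whether weavenet normalizes the value (for example by N) is not stated here. Assumptions: costs are in the default cost format (lower is better, e.g. ranks), `cab[i][j]` is a_i's cost for b_j, `cba[j][i]` is b_j's cost for a_i, and unmatched agents contribute nothing.

```python
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def side_costs(m: Matrix, cab: Matrix, cba: Matrix):
    """Total cost on side A and on side B for a binary matching m (hypothetical helper)."""
    rows, cols = range(len(m)), range(len(m[0]))
    cost_a = sum(m[i][j] * cab[i][j] for i in rows for j in cols)
    cost_b = sum(m[i][j] * cba[j][i] for i in rows for j in cols)
    return cost_a, cost_b


def sexequality_cost_single(m: Matrix, cab: Matrix, cba: Matrix):
    """How unevenly the total cost is split between the two sides."""
    cost_a, cost_b = side_costs(m, cab, cba)
    return abs(cost_a - cost_b)


cab = [[1, 2], [1, 2]]  # ranks: a_0 and a_1 both rank b_0 first
cba = [[1, 2], [2, 1]]  # b_0 ranks a_0 first, b_1 ranks a_1 first
m = [[1, 0], [0, 1]]
print(side_costs(m, cab, cba), sexequality_cost_single(m, cab, cba))  # (3, 2) 1
```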
- weavenet.metric.egalitarian_score(m, cab, cba_t, pformat=PreferenceFormat.cost)
Calculates the egalitarian score of each matching in batch m.
- Shape:
m: \((B, N, M)\)
cab: \((B, N, M)\)
cba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
A score vector.
- Return type
Tensor
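The textbook egalitarian cost adds up both partners' costs over every matched pair, so it measures total dissatisfaction regardless of which side bears it. A minimal sketch of that definition follows; `egalitarian_single` is hypothetical, and the exact scaling behind weavenet's "score" is not stated here. Assumptions: cost format (lower is better), `cab[i][j]` is a_i's cost for b_j, `cba[j][i]` is b_j's cost for a_i.

```python
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def egalitarian_single(m: Matrix, cab: Matrix, cba: Matrix) -> float:
    """Sum of both partners' costs over every matched pair (hypothetical sketch)."""
    return sum(
        cab[i][j] + cba[j][i]
        for i in range(len(m))
        for j in range(len(m[0]))
        if m[i][j]  # only matched pairs contribute
    )


cab = [[1, 2], [1, 2]]
cba = [[1, 2], [2, 1]]
batch = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]
print([egalitarian_single(m, cab, cba) for m in batch])  # [5, 7]
```

Under this definition the first matching is the more egalitarian of the two, since its total cost is lower.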
- weavenet.metric.balance_score(m, cab, cba_t, pformat=PreferenceFormat.cost)
Calculates the balance score of each matching in batch m.
- Shape:
m: \((B, N, M)\)
cab: \((B, N, M)\)
cba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
A score vector.
- Return type
Tensor
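In the balanced stable marriage formulation, the balance cost of a matching is the larger of the two side totals, i.e. the cost of the worse-off side. The sketch below computes that textbook quantity; `balance_single` is hypothetical, and any normalization weavenet applies is not stated here. Assumptions: cost format (lower is better), `cab[i][j]` is a_i's cost for b_j, `cba[j][i]` is b_j's cost for a_i.

```python
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def balance_single(m: Matrix, cab: Matrix, cba: Matrix) -> float:
    """Cost of the worse-off side of a binary matching m (hypothetical sketch)."""
    pairs = [(i, j) for i in range(len(m)) for j in range(len(m[0])) if m[i][j]]
    cost_a = sum(cab[i][j] for i, j in pairs)  # total cost on side A
    cost_b = sum(cba[j][i] for i, j in pairs)  # total cost on side B
    return max(cost_a, cost_b)


cab = [[1, 2], [1, 2]]
cba = [[1, 2], [2, 1]]
batch = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]
print([balance_single(m, cab, cba) for m in batch])  # [3, 4]
```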
- weavenet.metric.calc_all_fairness_metrics(m, cab, cba_t, pformat=PreferenceFormat.cost)
Calculates the three fairness scores (sex-equality cost, egalitarian score, and balance score).
- Shape:
m: \((B, N, M)\)
cab: \((B, N, M)\)
cba_t: \((B, M, N)\)
output: \((B)\)
- Parameters
- Returns
The three score vectors.
- Return type
Tuple[Tensor, Tensor, Tensor]
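All three fairness metrics can be computed from the same two side totals per matching, which is presumably why they are offered together. The sketch below does this for a batch under the textbook definitions (|A - B|, A + B, max(A, B)); `calc_fairness_batch` is hypothetical, its return order simply follows the order listed above, and lists stand in for the returned Tensors. Assumptions: cost format (lower is better), `cab[i][j]` is a_i's cost for b_j, `cba[j][i]` is b_j's cost for a_i.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def calc_fairness_batch(
    batch_m: Sequence[Matrix], batch_cab: Sequence[Matrix], batch_cba: Sequence[Matrix]
) -> Tuple[List[float], List[float], List[float]]:
    """Sex-equality, egalitarian and balance costs for each matching (hypothetical sketch)."""
    seq, egal, bal = [], [], []
    for m, cab, cba in zip(batch_m, batch_cab, batch_cba):
        pairs = [(i, j) for i in range(len(m)) for j in range(len(m[0])) if m[i][j]]
        cost_a = sum(cab[i][j] for i, j in pairs)  # side A total
        cost_b = sum(cba[j][i] for i, j in pairs)  # side B total
        seq.append(abs(cost_a - cost_b))
        egal.append(cost_a + cost_b)
        bal.append(max(cost_a, cost_b))
    return seq, egal, bal


cab = [[1, 2], [1, 2]]
cba = [[1, 2], [2, 1]]
batch = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]
seq, egal, bal = calc_fairness_batch(batch, [cab] * 2, [cba] * 2)
print(seq, egal, bal)  # [1, 1] [5, 7] [3, 4]

# max(x, y) == (x + y + |x - y|) / 2, so balance follows from the other two.
assert all(2 * b == e + s for s, e, b in zip(seq, egal, bal))
```

Under these definitions the balance cost is determined by the other two, so the trio carries two independent quantities per matching.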