Rethinking the role of frames for SE(3)-invariant crystal structure modeling

ICLR 2025
1 OMRON SINIC X Corporation, 2 Osaka University. * Equally contributed. The work was done while the first author was an intern at OMRON SINIC X.

TL;DR To make a GNN invariant to rotations, let's standardize the orientations of local atomic environments represented by internal self-attention weights, instead of directly standardizing the global structure.

Overview

Crystal structure modeling with graph neural networks is essential for various applications in materials informatics, and capturing SE(3)-invariant geometric features is a fundamental requirement for these networks. A straightforward approach is to model with orientation-standardized structures through structure-aligned coordinate systems, or “frames.” However, unlike molecules, determining frames for crystal structures is challenging due to their infinite and highly symmetric nature. In particular, existing methods rely on a statically fixed frame for each structure, determined solely by its structural information, regardless of the task under consideration. Here, we rethink the role of frames, questioning whether such simplistic alignment with the structure is sufficient, and propose the concept of dynamic frames. While accommodating the infinite and symmetric nature of crystals, these frames provide each atom with a dynamic view of its local environment, focusing on actively interacting atoms. We demonstrate this concept by utilizing the attention mechanism in a recent transformer-based crystal encoder, resulting in a new architecture called CrystalFramer. Extensive experiments show that CrystalFramer outperforms conventional frames and existing crystal encoders in various crystal property prediction tasks.

Problem

Crystal structure

Crystal structures are periodic arrangements of atoms in 3D space, serving as the source codes for diverse materials, such as permanent magnets, battery materials, and superconductors.

Figure: Crystal structure in 2D space

A crystal structure is typically described by its repeatable 3D slice called a unit cell. We assume a unit cell consisting of $N$ atoms and denote it as $(A, P, L)$:

  • $A = [a_1, a_2, \cdots, a_N] \in \mathbb{N}^{1 \times N}$: the species (atomic numbers) of unit cell atoms.
  • $P = [\bm{p}_1, \bm{p}_2, \cdots, \bm{p}_N] \in \mathbb{R}^{3 \times N}$: the 3D Cartesian coordinates of unit cell atoms.
  • $L = [\bm{\ell}_1, \bm{\ell}_2, \bm{\ell}_3] \in \mathbb{R}^{3 \times 3}$: lattice vectors that define periodic unit-cell translations in 3D space.

By tiling the unit cell to fill 3D space, the species and positions of atoms in the crystal structure are determined as follows.

$$\begin{aligned}\hat{A} &= \{a_{i(\bm{n})} \mid a_{i(\bm{n})}=a_i,\ \bm{n}\in\mathbb{Z}^3,\ 1\leq i \leq N\}\\ \hat{P} &= \{\bm{p}_{i(\bm{n})} \mid \bm{p}_{i(\bm{n})}=\bm{p}_i+L\bm{n},\ \bm{n}\in\mathbb{Z}^3,\ 1\leq i \leq N\}\end{aligned}$$

Here, we use $i$ to denote the $i$-th atom in the unit cell, and $i(\bm{n})$ to denote its duplicate under the 3D translation $L\bm{n} = n_1\bm{\ell}_1 + n_2\bm{\ell}_2 + n_3\bm{\ell}_3$. We use $j$ and $j(\bm{n})$ similarly.
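The tiling rule above can be sketched in a few lines of NumPy. This is only an illustration: the function name `tile_unit_cell` and the cutoff `n_max` are assumptions introduced here (the true crystal is infinite, so any finite range of $\bm{n}$ is a truncation), and the lattice vectors are stored as the columns of `L`, following the definition of $L$ above.

```python
import itertools
import numpy as np

def tile_unit_cell(P, L, n_max=1):
    """Return periodic images p_i + L n for all n in {-n_max, ..., n_max}^3.

    P: (3, N) Cartesian coordinates of unit-cell atoms.
    L: (3, 3) lattice vectors stored as columns.
    Returns (images, shifts) with images of shape (M, 3, N) and
    shifts of shape (M, 3), where M = (2 * n_max + 1) ** 3.
    """
    rng = range(-n_max, n_max + 1)
    shifts = np.array(list(itertools.product(rng, repeat=3)))  # integer vectors n
    offsets = shifts @ L.T            # row m holds (L n_m) as a row vector
    images = P[None, :, :] + offsets[:, :, None]
    return images, shifts

# Hypothetical toy cell: cubic lattice of edge 2 with two atoms.
L = 2.0 * np.eye(3)
P = np.array([[0.0, 1.0],
              [0.0, 1.0],
              [0.0, 1.0]])
images, shifts = tile_unit_cell(P, L, n_max=1)
print(images.shape)  # 27 translated copies of the two-atom cell
```

Atom species are simply repeated for every image, since $a_{i(\bm{n})} = a_i$.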

SE(3)-invariant structural modeling

We consider the problem of estimating the physical state of a given crystal structure, assuming that the state remains invariant under rigid transformations (i.e., rotations and translations). Such a state typically corresponds to material properties, such as formation energy and bandgap.

We represent the state of a crystal structure by a set of abstract atom-wise state features for the unit-cell atoms:

$$X = [\bm{x}_1, \bm{x}_2, \cdots, \bm{x}_N] \in \mathbb{R}^{d \times N}.$$

As input to a graph neural network (GNN), these features are usually initialized via atom embeddings:

$$X^{(0)} \gets \text{AtomEmbedding}(A),$$

which only symbolically represent atomic species. They are then evolved through message-passing layers

$$X^{(t+1)} \gets f^{(t)}(X^{(t)}, P, L)$$

to eventually reflect the atomic states appropriate for a target task.

Challenges in SE(3)-invariant GNNs

There are several approaches to ensuring SE(3) invariance in GNNs:

  • Invariant features: Leveraging inherently invariant geometric features, such as interatomic distances $\|\bm{p}_j - \bm{p}_i\|$ and angles between triplets $\cos(\bm{p}_j - \bm{p}_i, \bm{p}_k - \bm{p}_i)$, ensures SE(3) invariance. However, fully distance-based models have limited expressive power, and incorporating three-body interactions significantly increases computational complexity.
  • Frames: Another straightforward approach is to standardize the orientation of a given structure through a structure-aligned coordinate system called a frame. However, determining frames for crystal structures is challenging due to their infinite and highly symmetric nature.
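As a quick numerical illustration of the first approach, the sketch below checks that pairwise distances and a triplet angle cosine are unchanged by a random rigid motion. The helper names (`random_rotation`, `pairwise_distances`, `triplet_cosine`) and the random finite point set are assumptions made for this example; they are not part of any model discussed here.

```python
import numpy as np

def random_rotation(rng):
    """Draw a proper rotation matrix (det = +1) via a QR decomposition."""
    Q, R = np.linalg.qr(rng.normal(size=(3, 3)))
    Q = Q * np.sign(np.diag(R))       # fix column signs of the factorization
    if np.linalg.det(Q) < 0:          # flip one axis to exclude reflections
        Q[:, 0] = -Q[:, 0]
    return Q

def pairwise_distances(P):
    """P: (3, N) coordinates -> (N, N) matrix of ||p_j - p_i||."""
    diff = P[:, None, :] - P[:, :, None]   # diff[:, i, j] = p_j - p_i
    return np.linalg.norm(diff, axis=0)

def triplet_cosine(P, i, j, k):
    """Cosine of the angle at atom i between the directions to atoms j and k."""
    u, v = P[:, j] - P[:, i], P[:, k] - P[:, i]
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

rng = np.random.default_rng(0)
P = rng.normal(size=(3, 5))                       # five hypothetical atoms
R, t = random_rotation(rng), rng.normal(size=(3, 1))
P_moved = R @ P + t                               # rigid transformation

print(np.allclose(pairwise_distances(P), pairwise_distances(P_moved)))
print(np.isclose(triplet_cosine(P, 0, 1, 2), triplet_cosine(P_moved, 0, 1, 2)))
```

Both checks are expected to hold up to floating-point tolerance. Note that the triplet feature needs a loop over $(i, j, k)$, which hints at the extra cost of three-body terms mentioned above.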

We explore a new frame-based methodology to incorporate richer yet invariant structural information beyond distances.

Ideas

What is the role of frames?

Surely, it is to standardize the orientations of given structures so that GNN models can directly exploit 3D coordinate information as invariant geometric features.

―― Is that all?

Let’s dig deeper into how frames work in a GNN, whose message-passing layers are assumed to include the following general operation:

$$\bm{x}'_i = \sum_{j=1}^{N} \sum_{\bm{n}\in \mathbb{Z}^3} \,\, \underbrace{w_{ij(\bm{n})}}_{\text{Weight}} \,\, \underbrace{\bm{f}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, \hat{P})}_{\text{Message}}.$$

This equation states that the state $\bm{x}_i$ of each unit-cell atom $i$ is updated by receiving abstract influences, or “messages,” $\bm{f}_{i\gets j(\bm{n})}$, from atoms $j(\bm{n})$ in the structure, weighted by scalars $w_{ij(\bm{n})}$. In recent transformer models, these weights are determined dynamically via self-attention mechanisms.

Distance-based GNNs ensure SE(3) invariance by simply formulating $\bm{f}_{i\gets j(\bm{n})}$ with the interatomic distance, $r_{ij(\bm{n})} = \|\bm{p}_{j(\bm{n})} - \bm{p}_i \|$, as follows:

$$\bm{f}^{\text{dist}}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, \hat{P}) := \bm{h}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, r_{ij(\bm{n})} ).$$

The role of frames is to offer, for the design of the message function $\bm{f}_{i\gets j(\bm{n})}$, more informative invariant features beyond the distance through a structure-aligned coordinate system $F \in \mathbb{R}^{3 \times 3}$, as follows:

$$\bm{f}^{\text{frame}}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, \hat{P}) := \bm{h}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, r_{ij(\bm{n})}, F\bm{r}_{ij(\bm{n})} ),$$

where $\bm{r}_{ij(\bm{n})} = \bm{p}_{j(\bm{n})} - \bm{p}_i$ is the relative position vector, and its frame projection $F\bm{r}_{ij(\bm{n})}$ remains invariant under global rotations and translations of the crystal structure $\hat{P}$.
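The invariance of the frame projection can be verified in one short derivation. It rests on an assumption that is implied rather than spelled out here: the frame axes (the rows of $F$) are built from the structure itself, so they rotate together with it. The symbols $R$ (a rotation matrix) and $\bm{t}$ (a translation vector) are introduced only for this check.

```latex
% Assumption: under \hat{P} \to R\hat{P} + \bm{t}, every frame axis rotates
% with the structure, so the frame matrix transforms as F' = F R^T.
\begin{align*}
\bm{r}'_{ij(\bm{n})} &= (R\bm{p}_{j(\bm{n})} + \bm{t}) - (R\bm{p}_i + \bm{t}) = R\,\bm{r}_{ij(\bm{n})}, \\
F' &= F R^T, \\
F'\bm{r}'_{ij(\bm{n})} &= F R^T R\,\bm{r}_{ij(\bm{n})} = F\bm{r}_{ij(\bm{n})}.
\end{align*}
```

The translation cancels in the difference of positions, and the rotation cancels because $R^T R = I$.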

Dynamic frames

Given that the end-users of frames are the message functions $\bm{f}_{i\gets j(\bm{n})}$ in GNNs, shouldn't we tailor a frame for each message function in each layer so that the function receives a better-normalized structure?

―― We pursue this idea by introducing the concept of dynamic frames.

In each message-passing layer, the target atom $i$ receives more influence from atoms $j(\bm{n})$ with larger weights $w_{ij(\bm{n})}$, and no influence from atoms $j(\bm{n})$ with zero weights. This means that, when updating the state of atom $i$, this atom has its own partial and local view of the structure $\hat{P}$ through weights $w_{ij(\bm{n})}$ acting as a mask on the structure.

Figure: Dynamic frame in 2D space

As a dynamic frame, we therefore construct an atom-wise frame $F_i$ for each target atom $i$ by using this masked view of the structure $\hat{P}$ with weights $\bm{w}_{i}$, as follows:

$$F_i \gets \text{FrameConstruction}_i(\hat{P}, \bm{w}_{i}).$$

Typically, we define an orthonormal basis $F_i = [\bm{e}_1, \bm{e}_2, \bm{e}_3]^T$ as a frame, where the first and second axes point towards the primary and secondary influential directions of interatomic interactions. (See the paper for detailed definitions.)
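To make the idea concrete, here is a minimal sketch of one possible frame construction in the spirit of "primary and secondary influential directions." It is an assumption-laden reading, not the paper's definition: the function name `max_frame`, the tolerance `eps`, and the specific rule (take the highest-weight neighbor direction as $\bm{e}_1$, the highest-weight non-collinear direction orthogonalized against $\bm{e}_1$ as $\bm{e}_2$, and their cross product as $\bm{e}_3$) are all choices made here for illustration, over a finite list of neighbors.

```python
import numpy as np

def max_frame(r, w, eps=1e-6):
    """Build an orthonormal frame from weighted neighbor directions.

    r: (M, 3) relative positions r_ij(n) of M neighbors of the target atom.
    w: (M,) non-negative weights w_ij(n) (e.g., attention weights).
    Returns F of shape (3, 3) whose rows are e1, e2, e3.
    """
    d = r / np.linalg.norm(r, axis=1, keepdims=True)   # unit directions
    e1 = d[np.argmax(w)]                               # primary direction
    perp = d - (d @ e1)[:, None] * e1                  # parts orthogonal to e1
    perp_norm = np.linalg.norm(perp, axis=1)
    usable = perp_norm > eps                           # skip (anti)parallel ones
    if not usable.any():
        raise ValueError("all neighbor directions are collinear with e1")
    j = np.argmax(np.where(usable, w, -np.inf))        # secondary direction
    e2 = perp[j] / perp_norm[j]
    e3 = np.cross(e1, e2)                              # right-handed third axis
    return np.stack([e1, e2, e3])

# Hypothetical neighbors and weights for a single target atom.
r = np.array([[1.0, 0.2, 0.0],
              [0.1, 1.5, 0.3],
              [-0.7, 0.4, 1.1],
              [0.3, -0.9, 0.8]])
w = np.array([0.5, 0.3, 0.15, 0.05])
F_i = max_frame(r, w)
projected = r @ F_i.T      # row m holds F_i r_m, the frame-projected position
```

Because every axis is derived from the neighbor directions, rotating all neighbors rotates the frame with them, so `projected` stays the same. Since the weights change per layer (and per attention head), so does the frame, which is what makes it dynamic.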

This dynamic frame is then used to project the relative position vectors $\bm{r}_{ij(\bm{n})}$ in order to derive the messages for the target atom $i$, as follows:

$$\bm{f}^{\text{dynamic}}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, \hat{P}) := \bm{h}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, r_{ij(\bm{n})}, F_i \bm{r}_{ij(\bm{n})} ).$$

Importantly, our dynamic frames are constructed with the entire structure $\hat{P}$, rather than with a specific unit-cell representation $(P, L)$. Thus, our dynamic frames are invariant under the unit-cell variations within the same crystal structure.

CrystalFramer

We demonstrate the proposed concept of dynamic frames by utilizing the Crystalformer architecture (Taniai et al., ICLR 2024). Crystalformer employs the standard softmax self-attention for message passing, which is formulated as infinitely connected distance-decay attention as follows:

$$\begin{aligned}\bm{x}'_i &= \sum_{j=1}^{N} \sum_{\bm{n}\in \mathbb{Z}^3} w_{ij(\bm{n})} \, \bm{f}_{i\gets j(\bm{n})}(\bm{x}_{j(\bm{n})}, \hat{P}) \\ &= \sum_{j=1}^N\sum_{\bm{n}\in \mathbb{Z}^3} \underbrace{\frac{1}{Z_i} \exp\left(\frac{\bm{q}_i^T \bm{k}_{j}}{\sqrt{d_K}} - \frac{\|\bm{r}_{ij(\bm{n})}\|^2}{2\sigma_i^2}\right)}_{\text{Weight}} \, \underbrace{\left(\bm{v}_{j} +\bm{\psi}_{ij(\bm{n})}\right)}_{\text{Message}}.\end{aligned}$$

Here, the query $\bm{q}$, key $\bm{k}$, and value $\bm{v}$ are linear projections of the current state $\bm{x}$. The scalar $Z_i$ is the normalizer of the softmax attention weights. The vector $\bm{\psi}_{ij(\bm{n})}$ is a geometric relative position encoding for atoms $i$ and $j(\bm{n})$.
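The attention formula can be sketched for a single target atom as below. This is a simplified illustration under stated assumptions: the infinite sum over $\bm{n}$ is truncated to a finite set of $M$ shifts, and the function names, array shapes, and toy inputs are all hypothetical rather than taken from the Crystalformer code.

```python
import numpy as np

def distance_decay_attention(q_i, K, r_sq, sigma_i):
    """Softmax weights w_ij(n) with a Gaussian distance-decay term.

    q_i: (d_K,) query of target atom i.
    K: (N, d_K) keys of the unit-cell atoms j.
    r_sq: (N, M) squared distances ||r_ij(n)||^2 over M truncated shifts n.
    sigma_i: scalar width of the distance decay.
    Returns w of shape (N, M), normalized over all j and n (the 1/Z_i factor).
    """
    d_K = q_i.shape[0]
    logits = (K @ q_i)[:, None] / np.sqrt(d_K) - r_sq / (2.0 * sigma_i**2)
    logits = logits - logits.max()     # numerical stability; cancels in softmax
    w = np.exp(logits)
    return w / w.sum()

def attention_update(w, V, psi):
    """x'_i = sum_j sum_n w_ij(n) (v_j + psi_ij(n)).

    w: (N, M), V: (N, d) values, psi: (N, M, d) position encodings.
    """
    return np.einsum("jm,jmd->d", w, V[:, None, :] + psi)

# Hypothetical toy inputs: N = 2 atoms, M = 3 shifts, d_K = d = 2.
q_i = np.array([0.3, -0.1])
K = np.array([[0.2, 0.4], [-0.5, 0.1]])
V = np.array([[1.0, 0.0], [0.0, 1.0]])
r_sq = np.array([[1.0, 4.0, 9.0], [2.0, 2.0, 5.0]])
w = distance_decay_attention(q_i, K, r_sq, sigma_i=1.5)
x_new = attention_update(w, V, np.zeros((2, 3, 2)))
```

For a fixed atom $j$, the key term is shared by all of its periodic images, so the weights of the images differ only through the distance-decay factor.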

Originally, $\bm{\psi}_{ij(\bm{n})}$ simply encodes the scalar distance $r_{ij(\bm{n})}$ via a linear projection of Gaussian basis functions (GBFs). In this work, we enhance the model's expressive power by incorporating frame-based geometric features into Crystalformer's relative position encoding $\bm{\psi}_{ij(\bm{n})}$. This results in a new architecture, CrystalFramer.

Frame-based invariant features

Given the unit direction vector $\bar{\bm{r}}_{ij(\bm{n})} = \bm{r}_{ij(\bm{n})} / r_{ij(\bm{n})}$, we obtain its invariant representation $\bm{\theta}_{ij(\bm{n})} = F_i \bar{\bm{r}}_{ij(\bm{n})}$, where the $k$-th component represents the cosine of the angle between the $k$-th axis and the direction:

$$\theta_{ij(\bm{n})}^{(k)} = \bm{e}_k \cdot \bar{\bm{r}}_{ij(\bm{n})}.$$

Using GBFs $\bm{b}(x)$ as a mapping from a scalar to a vector, we linearly combine the distance-based and three angle-based edge features, as follows:

$$\bm{\psi}_{ij(\bm{n})} = W_0 \bm{b}_\text{dist}\left(r_{ij(\bm{n})}\right) + \sum_{k=1,2,3} W_k\bm{b}_\text{angl}\left(\theta_{ij(\bm{n})}^{(k)}\right).$$

This $\bm{\psi}_{ij(\bm{n})}$ as a whole essentially encodes the 3D relative position vector, $\bm{r}_{ij(\bm{n})} = \bm{p}_{j(\bm{n})} - \bm{p}_i$.
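The edge feature above can be sketched as follows. The Gaussian centers and widths, the feature dimension, and the function names `gbf` and `edge_feature` are placeholders chosen for this example, and the weight matrices are random rather than learned; the actual hyperparameters are in the paper and code.

```python
import numpy as np

def gbf(x, centers, width):
    """Gaussian basis functions: map a scalar x to a vector of Gaussian bumps."""
    x = np.asarray(x, dtype=float)[..., None]
    return np.exp(-((x - centers) ** 2) / (2.0 * width**2))

def edge_feature(r_vec, F_i, W0, W_ang, dist_centers, ang_centers,
                 dist_width=0.5, ang_width=0.25):
    """psi = W0 b_dist(r) + sum_k W_k b_angl(theta_k) for a single edge.

    r_vec: (3,) relative position r_ij(n); F_i: (3, 3) frame with rows e_k.
    W0: (d, B_dist); W_ang: (3, d, B_ang), one matrix per frame axis.
    """
    r = np.linalg.norm(r_vec)
    theta = F_i @ (r_vec / r)              # cosines e_k . r_bar, k = 1, 2, 3
    psi = W0 @ gbf(r, dist_centers, dist_width)
    for k in range(3):
        psi = psi + W_ang[k] @ gbf(theta[k], ang_centers, ang_width)
    return psi

# Hypothetical sizes: d = 4 output channels, 8 distance and 6 angle bases.
rng = np.random.default_rng(0)
dist_centers = np.linspace(0.0, 5.0, 8)
ang_centers = np.linspace(-1.0, 1.0, 6)    # cosines lie in [-1, 1]
W0 = rng.normal(size=(4, 8))
W_ang = rng.normal(size=(3, 4, 6))
r_vec = np.array([1.0, -0.5, 2.0])
psi = edge_feature(r_vec, np.eye(3), W0, W_ang, dist_centers, ang_centers)
```

Since the distance and the three cosines together determine $F_i \bm{r}_{ij(\bm{n})}$, the feature carries full 3D directional information while staying invariant when the structure and its frame rotate together.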

Architecture

Below is the architecture of CrystalFramer, where we have introduced dynamic frame construction and frame-based edge features, as highlighted in the figure.

Figure: CrystalFramer architecture

Given the multi-head self-attention mechanism, we dynamically construct a frame for each target atom, head, and layer during the self-attention operation.

Property Prediction Benchmarks

We evaluated the performance of CrystalFramer using two types of dynamic frames: weighted PCA frames and max frames. We compared these with existing crystal frames (PCA frames and lattice frames) and other state-of-the-art crystal encoders. For evaluation, we used three datasets: JARVIS (55,723 materials), Materials Project (69,239 materials), and OQMD (817,636 materials).

JARVIS dataset

| Method | E form | E total | BG (OPT) | BG (MBJ) | E hull |
|---|---|---|---|---|---|
| Matformer (Yan et al., 2022) | 0.0325 | 0.035 | 0.137 | 0.30 | 0.064 |
| PotNet (Lin et al., 2023) | 0.0294 | 0.032 | 0.127 | 0.27 | 0.055 |
| eComFormer (Yan et al., 2024) | 0.0284 | 0.032 | 0.124 | 0.28 | 0.044 |
| iComFormer (Yan et al., 2024) | 0.0272 | 0.0288 | 0.122 | 0.26 | 0.047 |
| Crystalformer (Taniai et al., 2024) | 0.0306 | 0.0320 | 0.128 | 0.274 | 0.0463 |
| ─ w/ PCA frames (Duval et al., 2023) | 0.0325 | 0.0334 | 0.144 | 0.292 | 0.0568 |
| ─ w/ lattice frames (Yan et al., 2024) | 0.0302 | 0.0323 | 0.125 | 0.274 | 0.0531 |
| ─ w/ static local frames | 0.0285 | 0.0292 | 0.122 | 0.261 | 0.0444 |
| ─ w/ weighted PCA frames (proposed) | 0.0287 | 0.0305 | 0.126 | 0.279 | 0.0444 |
| ─ w/ max frames (proposed) | 0.0263 | 0.0279 | 0.117 | 0.242 | 0.0471 |

Materials Project dataset

| Method | E form | BG | Bulk modulus | Shear modulus |
|---|---|---|---|---|
| Matformer (Yan et al., 2022) | 0.021 | 0.211 | 0.043 | 0.073 |
| PotNet (Lin et al., 2023) | 0.0188 | 0.204 | 0.040 | 0.065 |
| eComFormer (Yan et al., 2024) | 0.0182 | 0.202 | 0.0417 | 0.0729 |
| iComFormer (Yan et al., 2024) | 0.0183 | 0.193 | 0.0380 | 0.0637 |
| Crystalformer (Taniai et al., 2024) | 0.0186 | 0.198 | 0.0377 | 0.0689 |
| ─ w/ PCA frames (Duval et al., 2023) | 0.0197 | 0.217 | 0.0424 | 0.0719 |
| ─ w/ lattice frames (Yan et al., 2024) | 0.0194 | 0.212 | 0.0389 | 0.0720 |
| ─ w/ static local frames | 0.0178 | 0.191 | 0.0354 | 0.0708 |
| ─ w/ weighted PCA frames (proposed) | 0.0197 | 0.214 | 0.0423 | 0.0715 |
| ─ w/ max frames (proposed) | 0.0172 | 0.185 | 0.0338 | 0.0677 |

OQMD dataset

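| Method | # Blocks | E form | BG | E hull |
|---|---|---|---|---|
| Crystalformer (baseline) | 4 | 0.02115 | 0.06028 | 0.06759 |
| CrystalFramer (max frames) | 4 | 0.01871 | 0.05805 | 0.06607 |
| Crystalformer (baseline) | 8 | 0.02104 | 0.05986 | 0.06690 |
| CrystalFramer (max frames) | 8 | 0.01778 | 0.05785 | 0.06454 |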

Overall, CrystalFramer significantly improves the baseline performance of Crystalformer and outperforms most existing methods across various tasks and datasets.

Visual Analysis

Figure: Visualizations for MgSnF4 and C3N4

Max frames capture local motifs around the target atom, while weighted PCA frames look at the structure over broader areas. Both types of frames tend to focus on close neighbors in shallow layers and relatively distant neighbors in deeper layers.

Citation

@inproceedings{ito2025crystalframer,
  title     = {Rethinking the role of frames for SE(3)-invariant crystal structure modeling},
  author    = {Yusei Ito and 
               Tatsunori Taniai and
               Ryo Igarashi and
               Yoshitaka Ushiku and
               Kanta Ono},
  booktitle = {The Thirteenth International Conference on Learning Representations (ICLR 2025)},
  year      = {2025},
  url       = {https://openreview.net/forum?id=gzxDjnvBDa}
}