"""Implementation of the Cluster class.

Represents a group of sites of a given lattice. These are the building blocks
for a cluster basis of functions over configurational space.
"""

__author__ = "Luis Barroso-Luque, William Davidson Richard"
import json
import os
from fnmatch import fnmatch
from functools import cached_property
from io import StringIO
import numpy as np
from monty.io import zopen
from monty.json import MSONable
from pymatgen.core import Lattice, Site
from pymatgen.core.structure import Composition, SiteCollection
from pymatgen.util.coord import is_coord_subset
from ruamel.yaml import YAML
from smol.cofe.space.constants import SITE_TOL
from smol.cofe.space.domain import Vacancy, get_site_spaces
class Cluster(SiteCollection, MSONable):
    """An undecorated (no occupancies) cluster.

    Represented simply by a list of sites, its centroid, and the underlying
    lattice.

    You probably never need to instantiate this class directly. Look at
    ClusterSubspace to create orbits and clusters necessary for a CE.

    Attributes:
        frac_coords (ndarray): fractional coordinates of each site.
        lattice (Lattice): Underlying lattice of cluster.
        centroid (ndarray): Geometric centroid of included sites.
        id (int): ID of cluster.
            Used to identify the Cluster in a given ClusterSubspace.
    """

    def __init__(self, site_spaces, frac_coords, lattice):
        """Initialize Cluster.

        The given coordinates are translated by an integer lattice vector so
        that the centroid lies in [0, 1) along every lattice direction. The
        stored fractional coordinates, the centroid and the sites are all
        expressed after that translation.

        Args:
            site_spaces (list of SiteSpace):
                list of site spaces for the cluster
            frac_coords (Sequence):
                Sequence of frac coords for the site spaces
            lattice (Lattice):
                pymatgen Lattice object
        """
        frac_coords = np.array(frac_coords)
        centroid = np.average(frac_coords, axis=0)
        # integer translation that brings the centroid into the [0, 1) cell
        shift = np.floor(centroid)
        self._centroid = centroid - shift
        self._frac_coords = frac_coords - shift
        # Build the sites from the *shifted* coordinates so that the site
        # positions agree with frac_coords and centroid.
        self._sites = tuple(
            Site(site_space.composition, coords)
            for site_space, coords in zip(
                site_spaces, lattice.get_cartesian_coords(self._frac_coords)
            )
        )
        self._lattice = lattice
        self.id = None

    @property
    def centroid(self):
        """Return the centroid of cluster."""
        return self._centroid

    @property
    def frac_coords(self):
        """Return the fractional coordinates of cluster w.r.t the underlying lattice."""
        return self._frac_coords

    @cached_property
    def diameter(self):
        """Get maximum distance between any 2 sites in the cluster.

        The value is computed on first access and cached on the instance.
        """
        coords = self.lattice.get_cartesian_coords(self.frac_coords)
        # all pairwise squared distances at once, by broadcasting the
        # coordinate array against itself along a new axis
        all_d2 = np.sum((coords[None, :, :] - coords[:, None, :]) ** 2, axis=-1)
        return np.max(all_d2) ** 0.5

    @property
    def radius(self):
        """Get half the maximum distance between any 2 sites in the cluster."""
        return self.diameter / 2.0

    @property
    def lattice(self):
        """Return the underlying lattice."""
        return self._lattice

    @property
    def sites(self):
        """Return the sites of the cluster (a tuple of Site)."""
        return self._sites

    def get_distance(self, i: int, j: int) -> float:
        """Return distance between sites at index i and j.

        Args:
            i: Index of first site
            j: Index of second site

        Returns:
            Distance between sites at index i and index j.
        """
        return self[i].distance(self[j])

    def copy(self):
        """Return a copy of the cluster.

        Sites and lattice are copied. The copy is rebuilt through from_sites,
        so its ``id`` is None.
        """
        new_sites = [Site(site.species, site.coords.copy()) for site in self.sites]
        # go through the instance's own class so subclasses copy as themselves
        return self.from_sites(new_sites, self.lattice.copy())

    def assign_ids(self, cluster_id):
        """Recursively assign IDs to clusters after initialization.

        Args:
            cluster_id (int): ID to give this cluster.

        Returns:
            int: the next available ID (cluster_id + 1).
        """
        self.id = cluster_id
        return cluster_id + 1

    def to(self, fmt: str = None, filename: str = None):
        """Output the cluster to a file or string.

        This is basically a watered down version of pymatgen.Molecule.to.

        Args:
            fmt (str): Format to output to. Options are "yaml" and "json"
                only, case insensitive. If specified, it overrides whatever
                format the filename implies. If not specified, the format is
                determined from the filename, and defaults to JSON when no
                filename is given either.
            filename (str): If provided, output will be written to this file
                (opened with monty's zopen). Default is None, i.e. string
                output.

        Returns:
            (str) if filename is None. None otherwise.

        Raises:
            ValueError: if the format is unsupported or cannot be determined
                from the filename.
        """
        fmt = "" if fmt is None else fmt.lower()
        fname = os.path.basename(filename or "")
        if not fmt:
            # Only infer from the file name when no explicit format is given,
            # so that an explicit fmt really overrides the extension.
            if fnmatch(fname, "*.json*") or fnmatch(fname, "*.mson*"):
                fmt = "json"
            elif fnmatch(fname, "*.yaml*"):
                fmt = "yaml"
            elif not filename:
                fmt = "json"

        if fmt == "json":
            if filename:
                with zopen(filename, "wt", encoding="utf8") as f:
                    json.dump(self.as_dict(), f)
                return None
            return json.dumps(self.as_dict())
        if fmt == "yaml":
            yaml = YAML()
            if filename:
                # write to the full path given, not just its basename
                with zopen(filename, "wt", encoding="utf8") as f:
                    yaml.dump(self.as_dict(), f)
                return None
            sio = StringIO()
            yaml.dump(self.as_dict(), sio)
            return sio.getvalue()
        raise ValueError(f"Invalid format: `{str(fmt)}`")

    @classmethod
    def from_str(cls, input_string: str, fmt: str = "json"):
        """Read a cluster from a string.

        Args:
            input_string (str): String to parse.
            fmt (str): Format of input_string. Options are "yaml" and "json",
                case insensitive. Defaults to "json".

        Returns:
            Cluster

        Raises:
            ValueError: if fmt is not a supported format.
        """
        fmt = "json" if fmt is None else fmt.lower()
        if fmt == "json":
            d = json.loads(input_string)
        elif fmt == "yaml":
            d = YAML().load(input_string)
        else:
            raise ValueError(f"Invalid format: `{str(fmt)}`")
        return cls.from_dict(d)

    @classmethod
    def from_file(cls, filename: str):
        """Read a cluster from a file. Supported formats are json and yaml only.

        Args:
            filename (str): The filename to read from. The format is inferred
                from the name ("*.json*", "*.mson*" or "*.yaml*").

        Returns:
            Cluster

        Raises:
            ValueError: if the file type cannot be determined from its name.
        """
        filename = str(filename)
        with zopen(filename, mode="rt", errors="replace") as file:
            contents = file.read()
        fname = os.path.basename(filename)
        if fnmatch(fname, "*.json*") or fnmatch(fname, "*.mson*"):
            return cls.from_str(contents, fmt="json")
        if fnmatch(fname, "*.yaml*"):
            return cls.from_str(contents, fmt="yaml")
        raise ValueError("Cannot determine file type.")

    @classmethod
    def from_sites(cls, sites, lattice):
        """Create a cluster from a list of sites and lattice object.

        Args:
            sites (list of Site): sites of the cluster (Cartesian coords).
            lattice (Lattice): underlying lattice.

        Returns:
            Cluster
        """
        frac_coords = [lattice.get_fractional_coords(site.coords) for site in sites]
        site_spaces = get_site_spaces(sites)
        return cls(site_spaces, frac_coords, lattice)

    def __eq__(self, other):
        """Check equivalency of clusters up to an integer lattice translation.

        Two clusters are equal when they have the same number of sites and
        every site of this cluster matches (within SITE_TOL) a site of the
        other, after translating the other by the integer vector nearest to
        the difference of the two centroids.
        """
        if not isinstance(other, Cluster):
            return NotImplemented
        if self.frac_coords.shape != other.frac_coords.shape:
            return False
        othersites = other.frac_coords + np.round(self.centroid - other.centroid)
        return is_coord_subset(self.frac_coords, othersites, atol=SITE_TOL)

    def __str__(self):
        """Pretty print a cluster."""
        centroid_str = " ".join(
            [
                f"{j:0.6f}".rjust(12)
                for j in self.lattice.get_cartesian_coords(self.centroid)
            ]
        )
        centroid_str += " -> " + " ".join(
            [f"{j:0.6f}".rjust(12) for j in self.centroid]
        )
        outs = [
            f"Diameter : {self.diameter:0.4f}",
            f" Charge : {self.charge}",
            f"Centroid : {centroid_str}",
            f"Sites ({len(self)})",
        ]
        site_outs = []
        for i, site in enumerate(self):
            # NOTE(review): confirm the leading fields of each row (index and
            # species string) against the intended output format.
            site_outs.append(
                " ".join(
                    [
                        str(i),
                        site.species_string,
                        " ".join([f"{j:0.6f}".rjust(12) for j in site.coords]),
                        " ->",
                        " ".join(
                            [
                                f"{j:0.6f}".rjust(12)
                                for j in self.lattice.get_fractional_coords(
                                    site.coords
                                )
                            ]
                        ),
                    ]
                )
            )
        # separator rule as wide as a site row
        width = len(site_outs[0])
        outs.append(width * "-")
        outs += site_outs
        return "\n".join(outs)

    def __repr__(self):
        """Get cluster summary."""
        centroid_str = "({:.4f}, {:.4f}, {:.4f})".format(
            *self.lattice.get_cartesian_coords(self.centroid)
        ) + " -> [{:.4f}, {:.4f}, {:.4f}]".format(*self.centroid)
        outs = [
            f"No. sites: {len(self._sites)} Diameter: {self.diameter:0.4f}",
            f"Centroid: {centroid_str}",
        ]
        for s in self:
            # NOTE(review): per-site line mirrors the centroid format
            # "(cart) -> [frac]"; confirm the species prefix is wanted.
            outs.append(
                "{} ({:.4f}, {:.4f}, {:.4f})".format(s.species_string, *s.coords)
                + " -> [{:.4f}, {:.4f}, {:.4f}]".format(
                    *self.lattice.get_fractional_coords(s.coords)
                )
            )
        return "\n".join(outs)

    @classmethod
    def from_dict(cls, d):
        """Create a Cluster from serialized dict.

        Args:
            d (dict): dict as produced by as_dict.

        Returns:
            Cluster
        """
        sites = [Site.from_dict(item) for item in d["sites"]]
        # Force vacancies back to vacancies: as_dict records, for each site,
        # the symbols of species that were Vacancy objects. Every species is
        # kept; only those listed are rebuilt as Vacancy.
        for symbols, site in zip(d.get("vacancy_symbols", []), sites):
            site.species = Composition(
                {
                    (
                        Vacancy(spec.symbol, spec.oxidation_state, spec.spin)
                        if spec.symbol in symbols
                        else spec
                    ): amount
                    for spec, amount in site.species.items()
                }
            )
        return cls.from_sites(sites, Lattice.from_dict(d["lattice"]))

    def as_dict(self):
        """Get json-serialization dict representation.

        Returns:
            MSONable dict
        """
        cluster_d = {
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "lattice": self.lattice.as_dict(),
            "sites": [site.as_dict() for site in self.sites],
            # per site: symbols of Vacancy species, used by from_dict to
            # restore them as Vacancy objects
            "vacancy_symbols": [
                [spec.symbol for spec in site.species if isinstance(spec, Vacancy)]
                for site in self.sites
            ],
        }
        return cluster_d
# functions to render cluster using crystal-toolkit
try:
    from collections import defaultdict

    from crystal_toolkit.core.legend import Legend
    from crystal_toolkit.core.scene import Scene

    def get_cluster_scene(
        self,
        origin=None,
        legend=None,
    ) -> Scene:
        """Create CTK objects for the sites of a cluster.

        Args:
            self: Cluster object
            origin: Cartesian coordinate used as the scene origin. Defaults to
                minus the Cartesian coordinates of the cell point
                [0.5, 0.5, 0.5].
            legend: Legend for the sites. Defaults to Legend(self).

        Returns:
            CTK scene object to be rendered
        """
        # `is None` rather than `or`: an ndarray origin has no truth value
        if origin is None:
            origin = list(-self.lattice.get_cartesian_coords([0.5, 0.5, 0.5]))
        if legend is None:
            legend = Legend(self)
        # group the primitives of all sites by scene name
        primitives = defaultdict(list)
        for site in self:
            site_scene = site.get_scene(origin=origin, legend=legend)
            for scene in site_scene.contents:
                primitives[scene.name] += scene.contents
        # NOTE(review): name of the outer scene is a placeholder; confirm it,
        # and whether the outer scene should also carry `origin`.
        return Scene(
            name="Cluster",
            contents=[
                Scene(name=k, contents=v, origin=origin)
                for k, v in primitives.items()
            ],
        )

    Cluster.get_scene = get_cluster_scene
except ImportError:
    # crystal-toolkit is optional; without it clusters simply have no get_scene
    pass