Sharing references between Python and Rust

In 2018, the Mercurial project decided to use Rust to improve the performance and maintainability of its existing high-performance code. You can read more about it in the Oxidation Plan.

While one may argue that Rust took inspiration from Python in some aspects of its semantics, the two languages don't share many similarities at a lower level. Rust's strict borrowing rules and default immutability don't play very nicely with some of Python's features: dynamic typing, mutability rules, classes, and garbage collection, to name the big ones.

We have faced some interesting challenges when bridging the Python implementation with the new Rust code, and this is one that I have not found any literature about.

Technological stack

There are two main crates for bridging CPython and Rust: rust-cpython and PyO3. The latter is a fork of the former, created when rust-cpython seemed abandoned; it has more features, such as support for properties, and an arguably nicer syntax thanks to (then unstable) procedural macros.

We are, however, using rust-cpython because PyO3 does not (yet) compile on stable Rust. The idea behind this article is relevant regardless of the bridge used, though.

The issue at hand

During the rewrite of some core parts of Mercurial, we had to present a class-like interface to Python that would run Rust code. More often than not, that class implemented __iter__, which requires Python to hold a reference to a Rust iterator. Whenever we faced that issue, we simply copied the entire structure into something Python understands, which is terrible for both speed and memory usage. But it was good enough to keep the rewrite going.

However, as the frontier between Python and Rust became better defined, we knew we couldn't put off solving that issue any longer. My colleague Georges (over at blog.racinet.fr) took the opportunity of a long train trip to dig into shared references with a minimal example.

A minimalist example

Let's implement a stripped-down version of Python's set that only works with int, for simplicity's sake.

The basic features are pretty easy to implement with rust-cpython and would look like this:

extern crate cpython;
use cpython::*;

use std::cell::RefCell;
use std::collections::HashSet;

type Inner = HashSet<u32>;

py_class!(class RustSet |py| {
    data hs: RefCell<Inner>;

    def __new__(_cls) -> PyResult<RustSet> {
        Self::create_instance(py, RefCell::new(Inner::new()))
    }

    def __contains__(&self, v: u32) -> PyResult<bool> {
        Ok(self.hs(py).borrow().contains(&v))
    }

    def add(&self, v: u32) -> PyResult<PyObject> {
        self.hs(py).borrow_mut().insert(v);
        Ok(py.None())
    }

    def extend(&self, iterable: &PyObject) -> PyResult<PyObject> {
        let mut hs = self.hs(py).borrow_mut();
        for vobj in iterable.iter(py)? {
            hs.insert(vobj?.extract::<u32>(py)?);
        }
        Ok(py.None())
    }
});

The py_class! macro helps us define a Python class inside a shared library that Python imports like any other extension module. How exactly that is done is explained in the rust-cpython docs, and the repository for this experiment is linked at the bottom of this article. In short, the hs data attribute holds all the data in a Rust HashSet; the rest is basic encapsulation. Why a RefCell is used is explained below.

So far, nothing really exciting is happening, but this already lets us write the following Python script, which imports the generated .so as the sibling module shared_ref:

import sys

try:
    from .shared_ref import RustSet
except ImportError:
    sys.stderr.write(
        "Rust extension not found. Please run 'cargo build' first.\n"
    )
    sys.exit(1)


def test_basic():
    """Test basic scaffolding API: not needing to share refs."""
    rs = RustSet()
    rs.add(3)
    assert 3 in rs
    assert 4 not in rs
    rs.extend(x**2 for x in range(10))
    assert 4 in rs
    assert 81 in rs
    assert 65 not in rs


def run():
    test_basic()
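To experiment with the Rust side without building a Python extension, the class can be modeled in plain Rust. This is a minimal sketch using only std; the name RustSetModel is hypothetical, and there is no GIL or rust-cpython involved. Like py_class! methods, every method only receives &self, which is why the data sits in a RefCell.

```rust
use std::cell::RefCell;
use std::collections::HashSet;

pub type Inner = HashSet<u32>;

/// Hypothetical plain-Rust stand-in for the `RustSet` Python class.
/// Methods only get `&self`, so mutation goes through the `RefCell`.
pub struct RustSetModel {
    hs: RefCell<Inner>,
}

impl RustSetModel {
    pub fn new() -> Self {
        RustSetModel {
            hs: RefCell::new(Inner::new()),
        }
    }

    pub fn contains(&self, v: u32) -> bool {
        self.hs.borrow().contains(&v)
    }

    pub fn add(&self, v: u32) {
        self.hs.borrow_mut().insert(v);
    }

    /// Mirrors `extend`: a single mutable borrow for the whole loop.
    pub fn extend<I: IntoIterator<Item = u32>>(&self, iterable: I) {
        let mut hs = self.hs.borrow_mut();
        for v in iterable {
            hs.insert(v);
        }
    }
}
```

The usage mirrors test_basic: `rs.extend((0..10).map(|x| x * x))` plays the role of the Python generator expression.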

Now on to the good stuff.

We need an iterator class that exposes __next__ and __iter__ to use from the Python side.

use std::collections::hash_set::Iter;

py_class!(class RustSetIterator |py| {
    data hs: RustSet;
    data it: RefCell<Iter<'static, u32>>;

    def __next__(&self) -> PyResult<Option<u32>> {
        Ok(self.it(py).borrow_mut().next().map(|r| *r))
    }

    def __iter__(&self) -> PyResult<Self> {
        Ok(self.clone_ref(py))  // `clone_ref` gives a new Python reference
    }

});

Here, the hs data attribute holds a Python reference to the RustSet instance being iterated over, which keeps it alive for as long as the iterator exists. it is a Rust iterator (std::collections::hash_set::Iter) over the u32 items of our Inner type, with a 'static lifetime.

But, you might say, that iterator borrows from a HashSet that only exists at runtime; it is not 'static! Well, since Rust has no way of knowing what Python does with our objects, every data attribute must be Send + 'static. And since Python methods only ever receive &self, any data attribute that needs to change (the HashSet, or the iterator's position) is wrapped in a RefCell for interior mutability.
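To make that bound concrete, here is a small compile-time check in plain Rust. The helper assert_send_static is our own name, not part of rust-cpython; it only compiles if its type argument is Send + 'static, which both data attribute types satisfy as long as the iterator's lifetime is 'static.

```rust
use std::cell::RefCell;
use std::collections::hash_set::Iter;
use std::collections::HashSet;

/// Hypothetical helper: compiles only if `T: Send + 'static`.
pub fn assert_send_static<T: Send + 'static>() {}

pub fn check_data_attributes() {
    // RustSet's `hs` attribute: owns its data, so it is 'static.
    assert_send_static::<RefCell<HashSet<u32>>>();
    // RustSetIterator's `it` attribute: only accepted because of `'static`.
    assert_send_static::<RefCell<Iter<'static, u32>>>();
}
```

Replacing 'static with a lifetime borrowed from a local HashSet (for example `fn f<'a>() { assert_send_static::<Iter<'a, u32>>(); }`) is rejected by the compiler, which is exactly the gap the unsafe cast bridges.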

However, just writing 'static does not magically make the compiler happy. For that, we need to enter the world of unsafe, one of whose major use cases is FFI code just like ours.

Our RustSet class will then implement __iter__ the following way:

def __iter__(&self) -> PyResult<RustSetIterator> {
    let ptr = self.hs(py).as_ptr();
    let as_static: &'static Inner = unsafe {&*ptr};

    RustSetIterator::create_instance(
        py,
        self.clone_ref(py),
        RefCell::new(as_static.iter())
    )
}

Here, we take a raw pointer to our inner HashSet and (unsafely) tell Rust that its lifetime is 'static. This means we are forgoing the help of the borrow checker and will have to uphold its guarantees manually.
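The same trick can be reproduced outside of Python. In this sketch (std only, hypothetical names SetOwner, SetIter and iter_static), an Rc stands in for Python's reference counting and Rc::clone plays the role of clone_ref: the iterator keeps its owner alive, so the extended lifetime stays valid as long as nobody mutates the set, which nothing enforces yet.

```rust
use std::cell::RefCell;
use std::collections::hash_set::Iter;
use std::collections::HashSet;
use std::rc::Rc;

pub type Inner = HashSet<u32>;

/// Stand-in for a RustSet instance; `Rc` models Python's refcount.
pub struct SetOwner {
    pub hs: RefCell<Inner>,
}

/// Stand-in for RustSetIterator. Fields drop in declaration order,
/// so `it` is gone before the owner that keeps its data alive.
pub struct SetIter {
    it: Iter<'static, u32>,
    _owner: Rc<SetOwner>,
}

pub fn iter_static(owner: &Rc<SetOwner>) -> SetIter {
    let ptr = owner.hs.as_ptr();
    // SAFETY: sound only if the HashSet outlives the iterator (ensured by
    // the Rc clone below) AND is never mutated meanwhile (not enforced yet).
    let as_static: &'static Inner = unsafe { &*ptr };
    SetIter {
        it: as_static.iter(),
        _owner: Rc::clone(owner),
    }
}

impl Iterator for SetIter {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        self.it.next().copied()
    }
}
```

As in the Python test below, Rc::strong_count goes up by one while an iterator exists, and the iterator still works after the original handle is dropped.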

For the time being, we have a basic iterator interface, which looks like:

def test_iter():
    rs = RustSet()
    start_count = sys.getrefcount(rs)  # should be 2 (see Python doc)
    rs.extend(range(4))

    it = iter(rs)
    assert sys.getrefcount(rs) == start_count + 1
    assert set(it) == {0, 1, 2, 3}
    del it

    assert sys.getrefcount(rs) == start_count
    it2 = iter(rs)
    del rs
    assert set(it2) == {0, 1, 2, 3}

    del it2

Nice. But this is just reading from the iterator, and our naive lifetime trick will not work for very long once mutation is involved.

Let's add a clear method to introduce mutation of the data now shared between Python and Rust.

def clear(&self) -> PyResult<PyObject> {
    let mut hs = self.hs(py).borrow_mut();
    hs.clear();
    // Force freeing of underlying memory to underline the risk of
    // segfault
    hs.shrink_to_fit();
    Ok(py.None())
}

As the inline comment says, we also force the inner HashSet to release its unused memory, so that an iterator still reading from it is very likely to segfault (strictly speaking, this is undefined behavior).

This allows us to showcase the bug in a simple Python test:

def test_race_safety():
    rs = RustSet()
    # have Rust allocate some real amount of memory
    rs.extend(range(10000))
    it = iter(rs)

    # Trigger freeing the underlying memory
    rs.clear()

    next(it)  # segfault

Reference counting

We need to implement a higher-level system for ensuring memory safety. Let's add a reference counter to our RustSet that prevents mutation while Python holds a reference obtained through the unsafe pointer.

Inside RustSet:

// ...
data leak_count: RefCell<usize>;
// ...
def __new__(_cls) -> PyResult<RustSet> {
    Self::create_instance(py, RefCell::new(Inner::new()), RefCell::new(0))
}
def add(&self, v: u32) -> PyResult<PyObject> {
    // Changed from self.hs(py).borrow_mut()
    self.borrow_mut(py)?.insert(v);
    // ...
}
def extend(&self, iterable: &PyObject) -> PyResult<PyObject> {
    // Changed from self.hs(py).borrow_mut()
    let mut hs = self.borrow_mut(py)?;
    // ...
}
def __iter__(&self) -> PyResult<RustSetIterator> {
    RustSetIterator::create_instance(
        py,
        self.clone_ref(py),
        RefCell::new(self.leak_immutable(py).iter()),
        RefCell::new(false),  // the iterator's `done` flag
    )
}
def clear(&self) -> PyResult<PyObject> {
    // Changed from self.hs(py).borrow_mut()
    let mut hs = self.borrow_mut(py)?;
    // ...
}

For our private functions, we can step out of the py_class! macro and create the following impl block:

// Replaces the previous `use`
use std::cell::{RefCell, RefMut};

impl RustSet {
    fn leak_immutable(&self, py: Python) -> &'static Inner {
        let ptr = self.hs(py).as_ptr();
        *self.leak_count(py).borrow_mut() += 1;
        unsafe { &*ptr }
    }
    fn borrow_mut<'a>(&'a self, py: Python<'a>) -> PyResult<RefMut<'a, Inner>> {
        match *self.leak_count(py).borrow() {
            0 => Ok(self.hs(py).borrow_mut()),
            // AlreadyBorrowed is our custom Python exception type
            // (e.g. declared with rust-cpython's py_exception! macro)
            _ => Err(AlreadyBorrowed::new(
                py,
                "Can't mutate while there are immutable \
                references in Python objects",
            )),
        }
    }
    fn decrease_leak_count(&self, py: Python) {
        *self.leak_count(py).borrow_mut() -= 1;
    }
}

Here leak_immutable does the same lifetime extension "trick" that we did earlier, but also increments the leak_count data attribute to keep track of references. Of course, there is also a function to decrease the leak count.

The more interesting bit is the borrow_mut function, which wraps the borrow_mut of the inner RefCell and checks that the leak count is 0 before handing out a mutable borrow, raising a custom Python exception otherwise. This very simple mechanism is only possible because we hold the GIL (as witnessed by py, of type Python) and are thus guaranteed a single-threaded context.
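Stripped of rust-cpython, the leak-counting logic fits in a few lines of std-only Rust. This is a sketch under our own assumptions: LeakableSet and the unit struct AlreadyBorrowed are hypothetical stand-ins for the Python class and exception, the counter is a Cell rather than the article's RefCell<usize>, and leak_immutable is deliberately marked unsafe fn, since outside of the py_class! setting nothing stops a caller from misusing the 'static reference.

```rust
use std::cell::{Cell, RefCell, RefMut};
use std::collections::HashSet;

pub type Inner = HashSet<u32>;

/// Stand-in for the custom Python `AlreadyBorrowed` exception.
#[derive(Debug, PartialEq)]
pub struct AlreadyBorrowed(pub &'static str);

/// Hypothetical plain-Rust model of RustSet's leak counting.
pub struct LeakableSet {
    hs: RefCell<Inner>,
    leak_count: Cell<usize>,
}

impl LeakableSet {
    pub fn new() -> Self {
        LeakableSet {
            hs: RefCell::new(Inner::new()),
            leak_count: Cell::new(0),
        }
    }

    /// Lifetime-extension trick plus bookkeeping.
    ///
    /// # Safety
    /// `self` must stay alive and in place while the reference is used, and
    /// `decrease_leak_count` must be called once it is no longer used.
    pub unsafe fn leak_immutable(&self) -> &'static Inner {
        self.leak_count.set(self.leak_count.get() + 1);
        unsafe { &*self.hs.as_ptr() }
    }

    pub fn decrease_leak_count(&self) {
        self.leak_count.set(self.leak_count.get() - 1);
    }

    /// Wraps `RefCell::borrow_mut`, refusing while any leak is outstanding.
    pub fn borrow_mut(&self) -> Result<RefMut<'_, Inner>, AlreadyBorrowed> {
        match self.leak_count.get() {
            0 => Ok(self.hs.borrow_mut()),
            _ => Err(AlreadyBorrowed(
                "Can't mutate while there are immutable references",
            )),
        }
    }
}
```

Mutation is refused for exactly as long as a leaked reference is outstanding, which is the behavior the Python test below relies on.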

Inside our RustSetIterator, we can replace our __next__ method to make use of our newly defined reference counter:

data done: RefCell<bool>;

def __next__(&self) -> PyResult<Option<u32>> {
    let mut done = self.done(py).borrow_mut();
    if *done {
        return Ok(None);
    }
    Ok(match self.it(py).borrow_mut().next() {
        None => {
            // Exhausted: release our leak so the RustSet can be mutated again
            *done = true;
            self.hs(py).decrease_leak_count(py);
            None
        }
        Some(&r) => Some(r),
    })
}

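We use a simple boolean to record whether the iterator is done. This separates the lifetime of the iterator from that of its data: once exhausted, the iterator releases its leak and never touches the data again, even though the iterator object itself may live on, as shown in this example: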

def test_race_safety():
    rs = RustSet()
    # have Rust allocate some real amount of memory
    rs.extend(range(10000))
    it = iter(rs)

    # Trigger freeing the underlying memory
    try:
        rs.clear()
    except AlreadyBorrowed:
        pass  # \o/
    else:
        raise AssertionError("Should not have been able to clear RustSet "
                             "instance while holding an iterator on it")
        next(it)  # that would be a segfault

    # Consume iterator
    assert len([x for x in it]) == 10000
    rs.clear()

    # the consumed iterator is actually still usable (doesn't need the
    # data anymore to raise StopIteration)
    assert [x for x in it] == []

We are unable to mutate our RustSet when an iterator still refers to its data, hurray!
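The same race-safety scenario can be replayed in plain Rust. In this sketch, CountedSet, DoneFlagIter, iter_counted and try_clear are all hypothetical names; an Rc replaces the Python reference held by the iterator, and try_clear stands in for a mutating method guarded like the article's borrow_mut.

```rust
use std::cell::{Cell, RefCell};
use std::collections::hash_set::Iter;
use std::collections::HashSet;
use std::rc::Rc;

/// Hypothetical stand-in for RustSet with its leak counter.
pub struct CountedSet {
    pub hs: RefCell<HashSet<u32>>,
    pub leak_count: Cell<usize>,
}

impl CountedSet {
    /// Hypothetical mutating method, refused while leaks are outstanding.
    pub fn try_clear(&self) -> Result<(), &'static str> {
        if self.leak_count.get() > 0 {
            return Err("Can't mutate while there are immutable references");
        }
        self.hs.borrow_mut().clear();
        Ok(())
    }
}

/// Analogue of RustSetIterator with its `done` flag.
pub struct DoneFlagIter {
    it: Iter<'static, u32>,
    owner: Rc<CountedSet>,
    done: bool,
}

pub fn iter_counted(owner: &Rc<CountedSet>) -> DoneFlagIter {
    owner.leak_count.set(owner.leak_count.get() + 1);
    // SAFETY (assumed): `owner` keeps the set alive, and `try_clear` refuses
    // to mutate it until this iterator has released its leak.
    let as_static: &'static HashSet<u32> = unsafe { &*owner.hs.as_ptr() };
    DoneFlagIter {
        it: as_static.iter(),
        owner: Rc::clone(owner),
        done: false,
    }
}

impl Iterator for DoneFlagIter {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        if self.done {
            // Exhausted: answer without touching the (possibly mutated) data.
            return None;
        }
        match self.it.next() {
            None => {
                self.done = true;
                self.owner.leak_count.set(self.owner.leak_count.get() - 1);
                None
            }
            Some(&r) => Some(r),
        }
    }
}
```

Clearing fails while the iterator is live, succeeds once it has been consumed, and the exhausted iterator keeps returning None without reading the cleared set.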

Now, this differs from Python's behavior. Let's see what happens if we replicate the example with a standard set:

>>> s = set([1, 2, 3])
>>> it = iter(s)
>>> s.clear()
>>> next(it)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Set changed size during iteration

Python does not complain when we mutate the set, only when we try to use the iterator afterwards, which is arguably a little too late. Most well-written Python code should not need much adapting to this stricter behavior.

However, there are use cases where this "feature" is useful. I've seen some in Mercurial code where this implementation falls apart because an iterator is never exhausted; adding a del my_iterator only really fixes it on Python implementations that use reference counting, and so on. Plus, "breaking userspace" is always a bad idea unless it's really, really important, so we will need to improve on that later.

Harnessing the garbage collector

The code above runs just fine under our conditions, but this example is a tad simplistic, and our RustSet still needs a bit of work to play nicely with Python in the general case. Indeed, an iterator that Python garbage collects before it is exhausted never decreases the leak count: Rust still thinks the reference exists, so the leak is never released and the RustSet can never be mutated again.

As suggested in rust-cpython's error messages, the proper way to implement Drop for a Python object defined with py_class! is to implement it on one of its data members. Let's create a struct that manages our leak count on Drop:

struct RustSetLeakedRef {
    rs: RustSet,
}

impl RustSetLeakedRef {
    fn new(py: Python, rs: &RustSet) -> Self {
        RustSetLeakedRef {
            rs: rs.clone_ref(py),
        }
    }
}

impl Drop for RustSetLeakedRef {
    fn drop(&mut self) {
        let gil = Python::acquire_gil();
        let py = gil.python();
        self.rs.decrease_leak_count(py);
    }
}

This struct simply holds a Python reference to the RustSet instance and decreases the leak count when dropped. Cool: our RustSetIterator can now use this new struct to hook into rust-cpython's Drop mechanism:

py_class!(class RustSetIterator |py| {
    data rs: RefCell<Option<RustSetLeakedRef>>;
    // ...

    def __next__(&self) -> PyResult<Option<u32>> {
        let mut rs_opt = self.rs(py).borrow_mut();
        if rs_opt.is_some() {
            Ok(match self.it(py).borrow_mut().next() {
                None => {
                    // replace Some(rs) by None, hence drop RustSetLeakedRef
                    rs_opt.take();
                    None
                }
                Some(&r) => Some(r)
            })
        } else {
            Ok(None)
        }
    }
    // ...
});

You will notice that we got rid of the done boolean in favor of an Option, which is much more idiomatic.
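The Drop-on-a-member pattern combined with Option::take can be isolated in a few lines of plain Rust. In this sketch, Counter and LeakedRef are hypothetical stand-ins, there is no GIL to acquire, and, unlike the article (where leak_immutable does the increment), LeakedRef::new increments the count itself so that acquire and release live in one type.

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Hypothetical stand-in for a RustSet's leak counter, shared via Rc.
pub struct Counter {
    pub leak_count: Cell<usize>,
}

/// Analogue of RustSetLeakedRef: releases one leak when dropped.
pub struct LeakedRef {
    owner: Rc<Counter>,
}

impl LeakedRef {
    /// Sketch-specific choice: increment here, decrement in `drop`.
    pub fn new(owner: &Rc<Counter>) -> Self {
        owner.leak_count.set(owner.leak_count.get() + 1);
        LeakedRef {
            owner: Rc::clone(owner),
        }
    }
}

impl Drop for LeakedRef {
    fn drop(&mut self) {
        // No GIL in plain Rust; just release the leak.
        self.owner.leak_count.set(self.owner.leak_count.get() - 1);
    }
}
```

Replacing Some(guard) by None with take() drops the guard on the spot, which is what __next__ relies on when the iterator is exhausted.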

Lastly, our RustSet needs to update its __iter__ method to use the new RustSetLeakedRef:

def __iter__(&self) -> PyResult<RustSetIterator> {
    RustSetIterator::create_instance(
        py,
        RefCell::new(Some(RustSetLeakedRef::new(py, self))),
        RefCell::new(self.leak_immutable(py).iter()),
    )
}

We can now test this from the Python side:

def test_drop_before_end():
    rs = RustSet()
    # have Rust allocate some real amount of memory
    rs.extend(range(10))
    it = iter(rs)

    # Implementation of get_leak_count() left as
    # an exercise to the reader ;)
    assert rs.get_leak_count() == 1
    next(it)
    del it
    assert rs.get_leak_count() == 0

Works just as intended!
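For reference, here is the whole mechanism in one std-only model: a guard held in an Option inside the iterator, released either on exhaustion or when the iterator is dropped early. SharedSet, LeakGuard, GuardedIter and iter_guarded are hypothetical names, and Rc again stands in for Python references; the test mirrors test_drop_before_end.

```rust
use std::cell::{Cell, RefCell};
use std::collections::hash_set::Iter;
use std::collections::HashSet;
use std::rc::Rc;

/// Hypothetical stand-in for RustSet: the data plus its leak counter.
pub struct SharedSet {
    pub hs: RefCell<HashSet<u32>>,
    pub leak_count: Cell<usize>,
}

/// Analogue of RustSetLeakedRef: keeps the set alive, releases a leak on drop.
struct LeakGuard(Rc<SharedSet>);

impl Drop for LeakGuard {
    fn drop(&mut self) {
        self.0.leak_count.set(self.0.leak_count.get() - 1);
    }
}

/// Analogue of RustSetIterator; `it` is dropped before the guard.
pub struct GuardedIter {
    it: Iter<'static, u32>,
    guard: Option<LeakGuard>,
}

pub fn iter_guarded(owner: &Rc<SharedSet>) -> GuardedIter {
    owner.leak_count.set(owner.leak_count.get() + 1);
    // SAFETY (assumed): the guard's Rc keeps the set alive while `guard` is
    // Some, and mutation must be refused while leak_count > 0.
    let as_static: &'static HashSet<u32> = unsafe { &*owner.hs.as_ptr() };
    GuardedIter {
        it: as_static.iter(),
        guard: Some(LeakGuard(Rc::clone(owner))),
    }
}

impl Iterator for GuardedIter {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        if self.guard.is_none() {
            // Already exhausted: never touch the data again.
            return None;
        }
        match self.it.next() {
            None => {
                // Replace Some(guard) by None, dropping it and the leak.
                self.guard.take();
                None
            }
            Some(&r) => Some(r),
        }
    }
}
```

Dropping a partially consumed iterator brings the leak count back to zero, and so does fully consuming one.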

The only thing left to do to have complete support for garbage collection is to allow our RustSet to hold Python objects as well. The documentation of rust-cpython explains that the __traverse__ and __clear__ methods need to be implemented to dispatch the garbage collection to the inner objects.

As of now, the need has yet to present itself in the development of Mercurial... so we didn't do it. Sorry, maybe another time. :)

I've written a simple macro to automate parts of this boilerplate, but I am not very happy with it yet: it's not as easy to use as I would like, and it's missing some features, such as leaking multiple attributes, automatic data attributes, and method prefixing. I don't want to write a procedural macro, since those are not battle-tested yet.

If you have a better way of going about this problem, please let me know.

Links