Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion Lib/test/test_free_threading/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import threading
import unittest
from collections import deque
from collections import Counter, deque
from copy import copy
from test.support import threading_helper

Expand Down Expand Up @@ -49,5 +50,21 @@ def mutate():
)


class TestCounter(unittest.TestCase):
def test_update_concurrent(self):
# gh-151633: concurrent Counter.update calls must not cause use-after-free
# under free-threading.
NTHREADS = 4
PER_THREAD = 5000
c = Counter()
data = ['x'] * PER_THREAD
threads = [threading.Thread(target=c.update, args=(data,))
for _ in range(NTHREADS)]
for t in threads:
t.start()
for t in threads:
t.join()


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a race condition under free-threading when multiple threads update the same :class:`~collections.Counter` concurrently.
45 changes: 28 additions & 17 deletions Modules/_collectionsmodule.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_dict.h" // _PyDict_GetItem_KnownHash()
#include "pycore_dict.h" // _PyDict_GetItemRef_KnownHash_LockHeld()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_pyatomic_ft_wrappers.h"
Expand Down Expand Up @@ -2595,24 +2595,35 @@ _collections__count_elements_impl(PyObject *module, PyObject *mapping,
goto done;
}

oldval = _PyDict_GetItem_KnownHash(mapping, key, hash);
if (oldval == NULL) {
if (PyErr_Occurred())
goto done;
if (_PyDict_SetItem_KnownHash(mapping, key, one, hash) < 0)
goto done;
} else {
/* oldval is a borrowed reference. Keep it alive across
PyNumber_Add(), which can execute arbitrary user code and
mutate (or even clear) the underlying dict. */
Py_INCREF(oldval);
int found;
int cs_err = 0;
Py_BEGIN_CRITICAL_SECTION(mapping);
found = _PyDict_GetItemRef_KnownHash_LockHeld(
(PyDictObject *)mapping, key, hash, &oldval);
if (found < 0) {
cs_err = -1;
}
else if (found == 0) {
if (_PyDict_SetItem_KnownHash_LockHeld(
(PyDictObject *)mapping, key, one, hash) < 0) {
cs_err = -1;
}
}
else {
newval = PyNumber_Add(oldval, one);
Py_DECREF(oldval);
if (newval == NULL)
goto done;
if (_PyDict_SetItem_KnownHash(mapping, key, newval, hash) < 0)
goto done;
Py_CLEAR(newval);
if (newval == NULL) {
cs_err = -1;
}
else if (_PyDict_SetItem_KnownHash_LockHeld(
(PyDictObject *)mapping, key, newval, hash) < 0) {
cs_err = -1;
}
}
Py_END_CRITICAL_SECTION();
Py_CLEAR(newval);
if (cs_err < 0) {
goto done;
}
Py_DECREF(key);
}
Expand Down
Loading