Caching in huggingface datasets is broken. Sometimes.

There is a gotcha with huggingface datasets I did not know about before. The issue is hard to reproduce, I have some guesses about the root cause, and there is a relatively simple workaround. Imagine your code looks something like

import datasets
import my_lib

def map_fn(example, foo):
  example['a'] = my_lib.expensive_function_call(example, some_arg=foo)
  return example

ds = datasets.load_dataset('parquet', data_files=['file1.parquet'])['train']
ds = ds.map(map_fn, fn_kwargs={'foo': 3})

Using the .load_dataset function activates caching. If you run this code twice, the .map call should only be run once (displaying the tqdm progress bar). The second time, it loads from cache. All good.
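If you want to see what the cache keys look like, you can inspect the dataset's fingerprint and backing files (a sketch; cache_files is a public attribute of datasets.Dataset, _fingerprint is the private attribute the library itself uses, and the exact cache file naming is an implementation detail):

print(ds._fingerprint)      # fingerprint of the mapped dataset
for f in ds.cache_files:
  print(f['filename'])      # the mapped result lives in an arrow file named after a fingerprint

On a run that hits the cache, these values are identical to the previous run.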

In my case, caching worked unreliably. On some runs it worked; on others it did not. Renaming functions, adding comments, renaming the folder, and renaming other files in the folder all seemed to affect the caching. Strange!

Why it breaks

It all comes down to how huggingface decides whether to load from cache or not. The implementation of .map starts by computing a so-called fingerprint. If this fingerprint is associated with a dataset in the cache, it loads from cache. If the fingerprint is new, it recomputes everything. The fingerprint is derived from the input dataset's previous fingerprint (computed in .load_dataset), the mapping function, and the arguments passed in fn_kwargs. A hash is computed over all of these, and if that hash is not deterministic, strange behavior follows.
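You can compute these pieces yourself to see where the nondeterminism comes in (a sketch; Hasher is the helper datasets uses in fingerprint.py, and exactly how .map combines the pieces is an internal detail I am paraphrasing):

from datasets.fingerprint import Hasher

print(ds._fingerprint)          # the previous fingerprint, set by load_dataset
print(Hasher.hash(map_fn))      # hash of the serialized mapping function
print(Hasher.hash({'foo': 3}))  # hash of the arguments

If any of these differ between two runs of the same script, .map ends up with a new fingerprint and the cache is missed.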

To compute the hash, a hasher from the xxhash library is instantiated. The original fingerprint, the function, and the arguments are all serialized into bytes and fed into the hasher, and the hasher's digest constitutes the new fingerprint.

The old fingerprint is deterministic. The hashing is deterministic. But the serialization of functions and arguments into bytes may not be. To do this, huggingface uses a custom subclass of the serialization tools in the dill library. Determinism is a long-standing issue in dill, and the huggingface datasets customizations correct for many of the known cases of nondeterminism, but they are not perfect. In my concrete case, it was the function imported from my own library that did not serialize reliably. I suspect a race condition somewhere, and ‘irrelevant’ factors such as file names and folder structure determined whether it manifested.
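If you suspect the same problem, a quick diagnostic is to hash the imported function in a handful of fresh interpreter processes and compare the results (a sketch; my_lib and expensive_function_call are the placeholder names from the example above):

import subprocess
import sys

snippet = (
  "from datasets.fingerprint import Hasher; "
  "import my_lib; "
  "print(Hasher.hash(my_lib.expensive_function_call))"
)
# More than one distinct value means the serialization, and hence the
# fingerprint, is not deterministic across runs.
hashes = {
  subprocess.run([sys.executable, '-c', snippet], capture_output=True, text=True).stdout.strip()
  for _ in range(5)
}
print(hashes)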

For a longer discussion of how and why dill is not deterministic, read this issue: function pickles are only quasi-deterministic, does this matter? It saddens me that I could not find a relevant discussion of this in the huggingface documentation.

The workaround

Since the function body cannot be hashed deterministically, I read the details in datasets' fingerprint.py and made a poor man's version of that code that is guaranteed deterministic for my use case. It looks like


import datasets
import xxhash


def make_manual_fingerprint(dataset: datasets.Dataset, fn, kwargs) -> str:
  """Manual fingerprint for Dataset.map"""
  hasher = xxhash.xxh64()
  hasher.update(b'123') # a hashing salt of sorts
  hasher.update(dataset._fingerprint.encode())

  # Feeding only function name into hasher
  hasher.update(fn.__name__.encode())
  
  # The non-deterministic dill variant would be
  # import datasets.utils._dill as dill 
  # fn_as_bytes = dill.dumps(fn)
  # hasher.update(fn_as_bytes)  

  for k, v in kwargs.items():
    # arguments must have string valued keys
    hasher.update(k.encode())
    # I use only integer arguments, which have deterministic serialization
    assert isinstance(v, int), f"Value {v} is not an int"
    hasher.update(str(v).encode())

  new_fingerprint = str(hasher.hexdigest())
  return new_fingerprint

We pass the new fingerprint to the Dataset.map function using the new_fingerprint argument:

ds = datasets.load_dataset('parquet', data_files=['file1.parquet'])['train']
fn = map_fn
kwargs = {'foo': 3}
ds = ds.map(fn, fn_kwargs=kwargs, new_fingerprint=make_manual_fingerprint(ds, fn, kwargs))

The downside of this workaround is that one must manually invalidate the cache every time map_fn is updated. The implementation is more restrictive than it needs to be, but it is okay for my use case.
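One way to make that invalidation step explicit is to turn the fixed b'123' salt into a version constant and bump it whenever map_fn changes. A minimal sketch, assuming the same imports as above (MAP_FN_VERSION is a name I made up, not anything from the datasets API):

MAP_FN_VERSION = b'map_fn-v2'  # bump this whenever map_fn's behaviour changes

def make_versioned_fingerprint(dataset: datasets.Dataset, fn, kwargs) -> str:
  """Like make_manual_fingerprint, but with an explicit version salt."""
  hasher = xxhash.xxh64()
  hasher.update(MAP_FN_VERSION)  # the versioned salt replaces b'123'
  hasher.update(dataset._fingerprint.encode())
  hasher.update(fn.__name__.encode())
  for k, v in kwargs.items():
    hasher.update(k.encode())
    assert isinstance(v, int), f"Value {v} is not an int"
    hasher.update(str(v).encode())
  return str(hasher.hexdigest())

Editing the constant changes the fingerprint, which invalidates the cached result without touching anything else.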

Conclusion

Hashing data structures semantically is hard. One solution is to use a very simple hashing scheme and accept the hash collisions. Depending on your use case, that can be the simpler solution. Do not fix the general problem if you don’t have to.
