
UDF

User-defined functions (UDFs) run batch processing over a chain to compute new values. A UDF takes fields from one or more rows of the data and outputs new fields. UDFs can run at scale across multiple workers and processes.

A UDF can be any Python function. The classes below are useful for implementing a "stateful" UDF where a plain function is insufficient, for example when setup() or teardown() steps must run before or after the processing function.
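To make the distinction concrete, here is a minimal pure-Python sketch that does not call DataChain at all. run_udf is a hypothetical stand-in for the executor, assuming it calls setup() once, process() once per row, and teardown() once. A plain function needs none of that; a class is useful when state (here a lookup table) must be built before any row is processed.

```python
from collections.abc import Iterable


def add_one(num: int) -> int:
    """A stateless UDF: a plain function is enough."""
    return num + 1


class LookupLabel:
    """A stateful UDF (hypothetical): the lookup table is built in setup()."""

    def setup(self) -> None:
        # Stands in for expensive initialization such as loading a model.
        self.table = {0: "zero", 1: "one", 2: "two"}

    def process(self, num: int) -> str:
        return self.table.get(num, "many")

    def teardown(self) -> None:
        self.table.clear()


def run_udf(udf, rows: Iterable[int]) -> list:
    """Hypothetical executor: plain callables are applied directly;
    objects go through setup() -> process() per row -> teardown()."""
    if not hasattr(udf, "process"):
        return [udf(row) for row in rows]
    udf.setup()
    try:
        return [udf.process(row) for row in rows]
    finally:
        udf.teardown()


rows = [0, 1, 2, 5]
print(run_udf(add_one, rows))        # [1, 2, 3, 6]
print(run_udf(LookupLabel(), rows))  # ['zero', 'one', 'two', 'many']
```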

UDFBase

UDFBase()

Bases: AbstractUDF

Base class for stateful user-defined functions.

Any class that inherits from it must implement a process() method that takes input parameters from one or more rows in the chain and produces the expected output.

Optionally, the class may include these methods:
- setup() to run code on each worker before process() is called.
- teardown() to run code on each worker after process() completes.
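The call order on each worker can be illustrated with a toy scheduler that records which hook ran. CountingUDF and run_on_workers are hypothetical, and giving each worker a deep copy of the UDF is an assumption of this sketch (a real executor may instead serialize the UDF into separate processes). Workers run one after another here for simplicity.

```python
import copy


class CountingUDF:
    """Hypothetical UDF that records its lifecycle calls."""

    def __init__(self):
        self.events = []

    def setup(self):
        self.events.append("setup")

    def process(self, value):
        self.events.append("process")
        return value * 10

    def teardown(self):
        self.events.append("teardown")


def run_on_workers(udf, rows, n_workers):
    """Toy scheduler: each 'worker' gets its own copy of the UDF and a
    round-robin shard of the rows."""
    outputs, instances = [], []
    for w in range(n_workers):
        worker_udf = copy.deepcopy(udf)  # one instance per worker
        worker_udf.setup()               # once per worker, before any row
        try:
            outputs += [worker_udf.process(r) for r in rows[w::n_workers]]
        finally:
            worker_udf.teardown()        # once per worker, after its rows
        instances.append(worker_udf)
    return outputs, instances


outputs, instances = run_on_workers(CountingUDF(), [1, 2, 3, 4, 5], n_workers=2)
print(sorted(outputs))      # [10, 20, 30, 40, 50]
print(instances[0].events)  # ['setup', 'process', 'process', 'process', 'teardown']
```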

Example
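from datachain import DataChain, Mapper
import open_clip
import torch

class ImageEncoder(Mapper):
    def __init__(self, model_name: str, pretrained: str):
        super().__init__()  # initialize the base UDF attributes
        # Store only configuration here; the model is loaded in setup().
        self.model_name = model_name
        self.pretrained = pretrained

    def setup(self):
        # Runs once on each worker before process() is called.
        self.model, _, self.preprocess = (
            open_clip.create_model_and_transforms(
                self.model_name, self.pretrained
            )
        )

    def process(self, file) -> list[float]:
        img = file.get_value()  # load the image content
        img = self.preprocess(img).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():  # inference only, no gradients needed
            emb = self.model.encode_image(img)
        return emb[0].tolist()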

(
    DataChain.from_storage(
        "gs://datachain-demo/fashion-product-images/images", type="image"
    )
    .limit(5)
    .map(
        ImageEncoder("ViT-B-32", "laion2b_s34b_b79k"),
        params=["file"],
        output={"emb": list[float]},
    )
    .show()
)
Source code in datachain/lib/udf.py
def __init__(self):
    self.params = None
    self.output = None
    self.params_spec = None
    self.output_spec = None
    self._contains_stream = None
    self._catalog = None
    self._func = None

process

process(*args, **kwargs)

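Processing function that must be defined by the user. As the source below shows, the default implementation calls the wrapped function stored in self._func if one is set, and raises NotImplementedError otherwise.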

Source code in datachain/lib/udf.py
def process(self, *args, **kwargs):
    """Processing function that needs to be defined by user"""
    if not self._func:
        raise NotImplementedError("UDF processing is not implemented")
    return self._func(*args, **kwargs)
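Both paths of the default process() shown above can be exercised with a stand-alone replica. ToyUDFBase is hypothetical: it copies only the process() logic and, unlike the real class, accepts the wrapped function through its constructor for convenience. Doubler shows the usual approach of overriding process() in a subclass.

```python
class ToyUDFBase:
    """Replica of the process() logic shown above, for illustration only."""

    def __init__(self, func=None):
        self._func = func  # hypothetical: the real class sets this elsewhere

    def process(self, *args, **kwargs):
        if not self._func:
            raise NotImplementedError("UDF processing is not implemented")
        return self._func(*args, **kwargs)


class Doubler(ToyUDFBase):
    # Overriding process() replaces the delegation entirely.
    def process(self, x):
        return 2 * x


wrapped = ToyUDFBase(func=lambda x: x + 1)
print(wrapped.process(4))    # 5
print(Doubler().process(4))  # 8

try:
    ToyUDFBase().process(4)
except NotImplementedError as exc:
    print(exc)               # UDF processing is not implemented
```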

setup

setup()

Initialization hook executed on each worker before processing begins. Use it for tasks such as pre-loading ML models before scoring.
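A common pattern is to keep __init__ cheap (it only stores configuration) and do the expensive loading in setup(), so each worker loads the resource itself rather than once per row. In this sketch, load_model is a hypothetical placeholder for something like open_clip.create_model_and_transforms, and a module-level counter shows that loading happens once per setup() call.

```python
LOAD_CALLS = 0


def load_model(name: str):
    """Hypothetical expensive loader; returns a trivial 'model'."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return lambda text: len(text)  # 'scores' a string by its length


class Scorer:
    def __init__(self, model_name: str):
        # Only configuration here; nothing heavy.
        self.model_name = model_name
        self.model = None

    def setup(self):
        # Runs once per worker before process() is called.
        self.model = load_model(self.model_name)

    def process(self, text: str) -> int:
        return self.model(text)


scorer = Scorer("tiny-model")
scorer.setup()
scores = [scorer.process(t) for t in ["a", "bb", "ccc"]]
print(scores, LOAD_CALLS)  # [1, 2, 3] 1
```

One general reason for this split (an assumption about executors in general, not a statement about DataChain internals) is that a UDF object holding only configuration is cheap to copy or serialize to workers.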

Source code in datachain/lib/udf.py
def setup(self):
    """Initialization process executed on each worker before processing begins.
    This is needed for tasks like pre-loading ML models prior to scoring.
    """

teardown

teardown()

Cleanup hook executed on each worker process after processing ends. Use it for tasks such as closing connections to endpoints.
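For example, a UDF that looks values up in a database can open its connection in setup() and close it in teardown(). This sketch uses an in-memory sqlite3 database as a stand-in for a real endpoint; the PriceLookup class and its table are hypothetical, and the try/finally plays the role of an executor that guarantees teardown runs.

```python
import sqlite3


class PriceLookup:
    def setup(self):
        # Open one connection per worker and seed a demo table.
        self.conn = sqlite3.connect(":memory:")
        self.conn.execute(
            "CREATE TABLE prices (item TEXT PRIMARY KEY, price REAL)"
        )
        self.conn.executemany(
            "INSERT INTO prices VALUES (?, ?)", [("apple", 1.5), ("pear", 2.0)]
        )

    def process(self, item: str):
        row = self.conn.execute(
            "SELECT price FROM prices WHERE item = ?", (item,)
        ).fetchone()
        return row[0] if row else None

    def teardown(self):
        # Release the connection once this worker is done.
        self.conn.close()


udf = PriceLookup()
udf.setup()
try:
    prices = [udf.process(i) for i in ["apple", "pear", "kiwi"]]
finally:
    udf.teardown()
print(prices)  # [1.5, 2.0, None]
```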

Source code in datachain/lib/udf.py
def teardown(self):
    """Teardown process executed on each process/worker after processing ends.
    This is needed for tasks like closing connections to end-points.
    """

Aggregator

Aggregator()

Bases: UDFBase

Inherit from this class to implement a UDF that is passed to DataChain.agg().
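Conceptually, an aggregator sees all rows of a group at once and can emit any number of output rows for it. The toy below assumes that model: partition_and_aggregate is a hypothetical stand-in for the grouping an agg() call performs, and it assumes process() receives one list per input column for each partition. Check the DataChain.agg() reference for the exact calling convention.

```python
from collections import defaultdict


class SumPerCategory:
    """Hypothetical aggregator: one output row per partition."""

    def process(self, categories: list[str], amounts: list[float]):
        # All rows passed in here share the same category.
        yield categories[0], sum(amounts)


def partition_and_aggregate(udf, rows, partition_by, params):
    """Toy executor: group rows by `partition_by`, then call process()
    once per group with one list per column named in `params`."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[partition_by]].append(row)
    out = []
    for group in groups.values():
        columns = [[r[p] for r in group] for p in params]
        out.extend(udf.process(*columns))
    return out


rows = [
    {"category": "fruit", "amount": 3.0},
    {"category": "veg", "amount": 1.0},
    {"category": "fruit", "amount": 2.5},
]
result = partition_and_aggregate(
    SumPerCategory(), rows, partition_by="category", params=["category", "amount"]
)
print(result)  # [('fruit', 5.5), ('veg', 1.0)]
```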


BatchMapper

BatchMapper()

Bases: UDFBase

Inherit from this class to implement a UDF that is passed to DataChain.batch_map().
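A batch mapper trades per-row calls for per-batch calls, which helps when the work is vectorized (for example, running a model on a batch of images). This sketch assumes process() receives a list of values for one batch and yields one result per input row, in order; run_in_batches and batch_size are hypothetical names for illustration, not DataChain.batch_map() parameters.

```python
class SquareBatch:
    """Hypothetical batch mapper: one call handles a whole batch."""

    def __init__(self):
        self.calls = 0

    def process(self, values: list[int]):
        self.calls += 1  # counted when the generator is consumed
        # One output per input row, in the same order.
        for v in values:
            yield v * v


def run_in_batches(udf, values, batch_size):
    """Toy executor: slice the input into batches and flatten the results."""
    out = []
    for start in range(0, len(values), batch_size):
        out.extend(udf.process(values[start:start + batch_size]))
    return out


udf = SquareBatch()
result = run_in_batches(udf, [1, 2, 3, 4, 5], batch_size=2)
print(result, udf.calls)  # [1, 4, 9, 16, 25] 3
```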


Generator

Generator()

Bases: UDFBase

Inherit from this class to implement a UDF that is passed to DataChain.gen().
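A generator may emit zero, one, or many output rows per input row, for example when splitting text into pieces. This toy assumes process() yields its outputs and that the executor flattens them into one sequence; SplitWords and run_generator are hypothetical.

```python
class SplitWords:
    """Hypothetical generator: one output row per word in the input row."""

    def process(self, text: str):
        for position, word in enumerate(text.split()):
            yield position, word


def run_generator(udf, rows):
    """Toy executor: flatten everything process() yields, row by row."""
    return [out for row in rows for out in udf.process(row)]


result = run_generator(SplitWords(), ["hello world", "", "one"])
print(result)  # [(0, 'hello'), (1, 'world'), (0, 'one')]
```

The empty input row produces no output rows at all, which a one-to-one mapper could not express.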


Mapper

Mapper()

Bases: UDFBase

Inherit from this class to implement a UDF that is passed to DataChain.map().
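A mapper produces exactly one output row per input row; the ImageEncoder example near the top of this page is a Mapper. The minimal sketch below shows that one-to-one shape with a hypothetical CelsiusToFahrenheit class and a hypothetical run_mapper executor.

```python
class CelsiusToFahrenheit:
    """Hypothetical mapper: exactly one output per input row."""

    def process(self, celsius: float) -> float:
        return celsius * 9 / 5 + 32


def run_mapper(udf, rows):
    """Toy executor: call process() once per row, keeping row order."""
    return [udf.process(r) for r in rows]


result = run_mapper(CelsiusToFahrenheit(), [0.0, 100.0, -40.0])
print(result)  # [32.0, 212.0, -40.0]
```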
