MultiModal

ClipEncoder

Bases: EmbetterBase

CLIP model that can encode text and images.

Under the hood it wraps around the CLIP implementation from sentence-transformers.

Parameters:

| Name          | Type | Description                                                                               | Default           |
|---------------|------|-------------------------------------------------------------------------------------------|-------------------|
| `name`        |      | name of the model, see the supported options below                                         | `'clip-ViT-B-32'` |
| `device`      |      | manually override the cpu/gpu device, tries to grab the gpu automatically when available   | `None`            |
| `quantize`    |      | turns on quantization                                                                      | `False`           |
| `num_threads` |      | number of threads for pytorch to use, only has an effect when `device="cpu"`               | `None`            |

The following model names should be supported:

  • clip-ViT-B-32
  • clip-ViT-B-16
  • clip-ViT-L-14
  • clip-ViT-B-32-multilingual-v1
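
A minimal usage sketch for the text side, assuming `ClipEncoder` is exposed from `embetter.multi` (matching the source path below) and that sentence-transformers can download the model; the example sentences are made up:

```python
from embetter.multi import ClipEncoder

# Text and images share the same CLIP embedding space,
# so the same encoder instance can embed both.
encoder = ClipEncoder(name="clip-ViT-B-32")

texts = ["a photo of a dog", "a photo of a cat"]
embeddings = encoder.transform(texts)
print(embeddings.shape)  # (2, 512) for clip-ViT-B-32
```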
Source code in embetter/multi/_clip.py
class ClipEncoder(EmbetterBase):
    """
    CLIP model that can encode text and images.

    Under the hood it wraps around the CLIP implementation from [sentence-transformers](https://sbert.net/docs/pretrained_models.html?highlight=clip).

    Arguments:
        name: name of model, see available options
        device: manually override cpu/gpu device, tries to grab gpu automatically when available
        quantize: turns on quantization
        num_threads: number of threads for pytorch to use, only affects when device=cpu

    The following model names should be supported:

    - `clip-ViT-B-32`
    - `clip-ViT-B-16`
    - `clip-ViT-L-14`
    - `clip-ViT-B-32-multilingual-v1`
    """

    def __init__(
        self, name="clip-ViT-B-32", device=None, quantize=False, num_threads=None
    ):
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.name = name
        self.device = device
        self.tfm = SBERT(name, device=self.device)
        self.num_threads = num_threads
        self.quantize = quantize
        if quantize:
            self.tfm = quantize_dynamic(self.tfm, {Linear})
        if num_threads:
            if self.device.type == "cpu":
                torch.set_num_threads(num_threads)

    def transform(self, X, y=None):
        """Transforms the text into a numeric representation."""
        # Convert pd.Series objects into a format compatible with .encode()
        if isinstance(X, pd.Series):
            X = X.to_numpy()

        return self.tfm.encode(X)
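
Since `ClipEncoder` follows the scikit-learn transformer API (the `transform` method above), it can be dropped into a pipeline. The sketch below embeds a column of image paths; it assumes the `ColumnGrabber` and `ImageLoader` helpers from `embetter.grab` and `embetter.vision`, and the dataframe and file paths are hypothetical.

```python
import pandas as pd
from sklearn.pipeline import make_pipeline

from embetter.grab import ColumnGrabber
from embetter.multi import ClipEncoder
from embetter.vision import ImageLoader

# Hypothetical dataframe with a column of image file paths.
df = pd.DataFrame({"img_path": ["img/dog.png", "img/cat.png"]})

# Grab the path column, load each image, then embed it with CLIP.
# Text embedded with the same ClipEncoder lands in the same vector space,
# which is what enables text-to-image search.
image_pipeline = make_pipeline(
    ColumnGrabber("img_path"),
    ImageLoader(convert="RGB"),
    ClipEncoder(),
)

X = image_pipeline.fit_transform(df)  # shape: (n_rows, embedding_dim)
```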