I want to deploy some ViT models on an iPhone. I followed https://machinelearning.apple.com/research/vision-transformers for the deployment and wrote a simple demo based on the code from https://github.com/apple/ml-vision-transformers-ane. However, I found that the uncached load time on the phone is very long. According to the blog, the input is already aligned to 64 bytes, yet loading is still very slow. Is there any way to speed it up? This is my test case:
import torch
import coremltools as ct
import math
from torch import nn
class SelfAttn(torch.nn.Module):
    def __init__(self, window_size, num_heads, dim, dim_out):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_out = dim_out
        # 1x1 convolutions act as the Q/K/V projections (ANE-friendly layout)
        self.q_proj = nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=1)
        self.k_proj = nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=1)
        self.v_proj = nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=1)

    def forward(self, x):
        B, HW, C = x.shape
        image_shape = (B, C, self.window_size, self.window_size)
        x_2d = x.permute((0, 2, 1)).reshape(image_shape)   # BCHW
        x_flat = torch.unsqueeze(x.permute((0, 2, 1)), 2)  # BC1L
        q, k, v_2d = self.q_proj(x_flat), self.k_proj(x_flat), self.v_proj(x_2d)
        # split into per-head chunks along the channel axis
        mh_q = torch.split(q, self.dim_out // self.num_heads, dim=1)  # BC1L
        mh_v = torch.split(
            v_2d.reshape(B, -1, x_flat.shape[2], x_flat.shape[3]),
            self.dim_out // self.num_heads,
            dim=1,
        )
        mh_k = torch.split(
            torch.permute(k, (0, 3, 2, 1)),
            self.dim_out // self.num_heads,
            dim=3,
        )
        scale_factor = 1 / math.sqrt(mh_q[0].size(1))
        attn_weights = [
            torch.einsum("bchq,bkhc->bkhq", qi, ki) * scale_factor
            for qi, ki in zip(mh_q, mh_k)
        ]
        attn_weights = [
            torch.softmax(aw, dim=1) for aw in attn_weights
        ]  # softmax applied on channel "C"
        mh_x = [
            torch.einsum("bkhq,bchk->bchq", wi, vi)
            for wi, vi in zip(attn_weights, mh_v)
        ]
        x = torch.cat(mh_x, dim=1)
        return x
window_size = 8
path_batch = 1024
emb_dim = 96
emb_dim_out = 96

x = torch.rand(path_batch, window_size * window_size, emb_dim)
qkv_layer = SelfAttn(window_size, 1, emb_dim, emb_dim_out)
jit = torch.jit.trace(qkv_layer, (x,))

mlmod_fixed_shape = ct.convert(
    jit,
    inputs=[
        ct.TensorType("x", x.shape),
    ],
    convert_to="mlprogram",
)
mlmodel_path = "test_ane.mlpackage"
mlmod_fixed_shape.save(mlmodel_path)
The uncached load took nearly 36 seconds for what is essentially a single matrix multiplication.
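For reference, this is roughly how the load time can be reproduced from Python with coremltools. It is only a sketch, not the on-device harness; on the phone the load goes through Core ML's MLModel API, and the compute-unit choice below is an assumption on my part:

import time

import numpy as np
import coremltools as ct

# Time the first ("uncached") load of the freshly converted package.
start = time.perf_counter()
loaded = ct.models.MLModel(
    "test_ane.mlpackage",
    compute_units=ct.ComputeUnit.ALL,  # assumption: let Core ML pick ANE/GPU/CPU
)
print(f"load: {time.perf_counter() - start:.1f} s")

# Sanity-check a prediction with the same input name/shape used at conversion time:
# "x" with shape (path_batch, window_size * window_size, emb_dim) = (1024, 64, 96).
start = time.perf_counter()
out = loaded.predict({"x": np.random.rand(1024, 64, 96).astype(np.float32)})
print(f"predict: {time.perf_counter() - start:.3f} s")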