Additionally, in case it is needed, here is my conversion script:
import torch
import coremltools as ct
import numpy as np
import logging
from ball_tracker_model import BallTrackerNet
def convert_to_coreml(
    model_path,
    output_path="BallTracker.mlpackage",
    input_shape=(1, 9, 360, 640),
    output_name=None,
    compute_precision=None,
):
    """Convert a BallTrackerNet checkpoint into a Core ML ML Program package.

    Args:
        model_path: Path to a file whose contents can be passed to
            ``BallTrackerNet.load_state_dict`` (loaded onto the CPU).
        output_path: Destination of the saved ``.mlpackage``. Defaults to
            ``"BallTracker.mlpackage"`` in the current working directory.
        input_shape: Fixed shape of the single ``input_frames`` input. It is
            used both for the tracing example and for the Core ML input spec,
            so the two can never disagree.
        output_name: Optional name for the model output. ``None`` keeps the
            name auto-generated by coremltools (e.g. ``var_462``).
        compute_precision: Forwarded to ``ct.convert``. ``None`` keeps the
            coremltools default for ML Programs (Float16 weights/compute);
            pass ``ct.precision.FLOAT32`` to keep full precision.

    Returns:
        The converted Core ML model (also saved to ``output_path``).
    """
    logging.basicConfig(level=logging.DEBUG)

    model = BallTrackerNet()
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # Inference mode before tracing so layers such as BatchNorm use their
    # stored statistics in the recorded graph.
    model.eval()

    # The tensor values are irrelevant; tracing only records the ops executed
    # for an input of this shape. It does not validate shapes.
    example_input = torch.rand(*input_shape)
    traced_model = torch.jit.trace(model, example_input)

    outputs = None if output_name is None else [ct.TensorType(name=output_name)]

    model_coreml = ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(
                name="input_frames",
                shape=tuple(input_shape),
                dtype=np.float32,
            )
        ],
        outputs=outputs,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.iOS15,
        compute_precision=compute_precision,
    )
    model_coreml.save(output_path)
    return model_coreml
# Run the conversion. A failure is reported rather than re-raised, but the
# full traceback is printed as well: the exception message alone is usually
# not enough to locate which op or layer failed to convert.
import traceback

try:
    model = convert_to_coreml("balltrackerbest.pt")
except Exception as e:
    print(f"Conversion error: {e}")
    traceback.print_exc()
else:
    print("Conversion successful!")
Thanks again!
Post
Replies
Boosts
Views
Activity
Yes, this is what I am seeing in Xcode.
Running `xcrun coremlcompiler metadata path/to/model.mlpackage` prints the following:
[
{
"metadataOutputVersion" : "3.0",
"storagePrecision" : "Float16",
"outputSchema" : [
{
"hasShapeFlexibility" : "0",
"isOptional" : "0",
"dataType" : "Float32",
"formattedType" : "MultiArray (Float32 1 × 256 × 230400)",
"shortDescription" : "",
"shape" : "[1, 256, 230400]",
"name" : "var_462",
"type" : "MultiArray"
}
],
"modelParameters" : [
],
"specificationVersion" : 6,
"mlProgramOperationTypeHistogram" : {
"Cast" : 2,
"Conv" : 18,
"Relu" : 18,
"BatchNorm" : 18,
"Reshape" : 1,
"UpsampleNearestNeighbor" : 3,
"MaxPool" : 3
},
"computePrecision" : "Mixed (Float16, Float32, Int32)",
"isUpdatable" : "0",
"availability" : {
"macOS" : "12.0",
"tvOS" : "15.0",
"visionOS" : "1.0",
"watchOS" : "8.0",
"iOS" : "15.0",
"macCatalyst" : "15.0"
},
"modelType" : {
"name" : "MLModelType_mlProgram"
},
"userDefinedMetadata" : {
"com.github.apple.coremltools.source_dialect" : "TorchScript",
"com.github.apple.coremltools.source" : "torch==2.5.1",
"com.github.apple.coremltools.version" : "8.1"
},
"inputSchema" : [
{
"hasShapeFlexibility" : "0",
"isOptional" : "0",
"dataType" : "Float32",
"formattedType" : "MultiArray (Float32 1 × 9 × 360 × 640)",
"shortDescription" : "",
"shape" : "[1, 9, 360, 640]",
"name" : "input_frames",
"type" : "MultiArray"
}
],
"generatedClassName" : "BallTracker",
"method" : "predict"
}
]