I have built a text classifier model, and I also want to be able to train it on the device.
When a text is classified incorrectly, the user should be able to correct it and update the model on the device.
Code :
//
// SpamClassifierHelper.swift
// LearningML
//
// Created by Himan Dhawan on 7/1/24.
//
import Foundation
import CreateMLComponents
import CoreML
import NaturalLanguage
/// The two classes a message can be assigned to.
///
/// The raw value of each case is the label string passed to the model
/// when building training data (`"spam"` / `"ham"`).
enum TextClassifier: String {
    /// An unwanted message; raw value `"spam"` (inferred from the case name).
    case spam
    /// A legitimate message; raw value `"ham"`.
    case notASpam = "ham"
}
/// Static facade around the `SpamClassifier` Core ML model.
///
/// It serves predictions from the most recent model available (an on-device
/// updated copy if one exists, otherwise the model shipped in the bundle) and
/// can start an `MLUpdateTask` to personalize the model with a corrected example.
///
/// NOTE(review): `MLUpdateTask` can only update a model that was exported as
/// *updatable*; confirm that `SpamClassifier.mlmodelc` is such a model,
/// otherwise `updateModel(newEntryText:spam:)` will throw when the task is created.
class SpamClassifierModel {

    // MARK: - Private Type Properties

    /// The updated Spam Classifier model, once it has been loaded from disk.
    private static var updatedSpamClassifier: SpamClassifier?

    /// The default Spam Classifier model.
    ///
    /// Loaded lazily on first use and cached, so the bundled model is not
    /// re-instantiated on every access.
    private static let defaultSpamClassifier: SpamClassifier = {
        do {
            return try SpamClassifier(configuration: .init())
        } catch {
            fatalError("Couldn't load SpamClassifier due to: \(error.localizedDescription)")
        }
    }()

    /// The Spam Classifier model currently in use.
    ///
    /// Prefers the updated model. If none is in memory yet, it first tries to
    /// load one persisted by an earlier update (e.g. in a previous app launch)
    /// and only then falls back to the default bundled model.
    static var liveModel: SpamClassifier {
        if updatedSpamClassifier == nil {
            loadUpdatedModel()
        }
        return updatedSpamClassifier ?? defaultSpamClassifier
    }

    /// The location of the app's Application Support directory for the user.
    private static let appDirectory = FileManager.default.urls(for: .applicationSupportDirectory,
                                                               in: .userDomainMask).first!

    /// The URL of the compiled `SpamClassifier` model inside the bundle that contains this class.
    ///
    /// Force-unwrapped: a missing bundled model is a packaging error.
    class var urlOfModelInThisBundle: URL {
        let bundle = Bundle(for: self)
        return bundle.url(forResource: "SpamClassifier", withExtension: "mlmodelc")!
    }

    /// The default Spam Classifier model's file URL.
    private static let defaultModelURL = urlOfModelInThisBundle

    /// The permanent location of the updated Spam Classifier model.
    private static let updatedModelURL = appDirectory.appendingPathComponent("personalized.mlmodelc")

    /// The temporary location of the updated Spam Classifier model.
    private static let tempUpdatedModelURL = appDirectory.appendingPathComponent("personalized_tmp.mlmodelc")

    // MARK: - Public Type Methods

    /// Classifies `value` with the model currently in use.
    ///
    /// - Parameter value: The text to classify.
    /// - Returns: The predicted label (`nil` if the model produced none) and the
    ///   confidence of the top hypothesis as a percentage string with two
    ///   decimals (`"0.00"` when no hypothesis is available).
    /// - Throws: Any error raised while wrapping the Core ML model in an `NLModel`.
    static func predictLabelFor(_ value: String) throws -> (predication: String?, confidence: String) {
        let classifier = try NLModel(mlModel: liveModel.model)
        let label = classifier.predictedLabel(for: value)
        let topConfidence = classifier.predictedLabelHypotheses(for: value, maximumCount: 1).first?.value ?? 0
        return (label, String(format: "%.2f", topConfidence * 100))
    }

    /// Starts an on-device update of the model with one corrected example.
    ///
    /// The update runs asynchronously; when it finishes successfully the new
    /// model is written to disk and becomes the live model.
    ///
    /// - Parameters:
    ///   - newEntryText: The message text that was misclassified.
    ///   - spam: The correct class for `newEntryText`.
    /// - Throws: Any error raised while building the training features or
    ///   creating the `MLUpdateTask`.
    static func updateModel(newEntryText: String, spam: TextClassifier) throws {
        // Continue from the previously personalized model when there is one, so
        // successive corrections accumulate instead of each restarting from the
        // bundled model.
        let baseModelURL = FileManager.default.fileExists(atPath: updatedModelURL.path)
            ? updatedModelURL
            : defaultModelURL

        // The message goes under "text" and its class under "label".
        // NOTE(review): these names must match the model's training inputs — confirm
        // against the model description.
        let featureProvider = try MLDictionaryFeatureProvider(dictionary: [
            "text": MLFeatureValue(string: newEntryText),
            "label": MLFeatureValue(string: spam.rawValue)
        ])
        let batchProvider = MLArrayBatchProvider(array: [featureProvider])

        let updateTask = try MLUpdateTask(forModelAt: baseModelURL,
                                          trainingData: batchProvider,
                                          configuration: nil,
                                          completionHandler: { context in
            // Do not persist anything if the task reported an error.
            if let error = context.task.error {
                print("Model update failed: \(error)")
                return
            }
            saveUpdatedModel(context.model)
        })
        updateTask.resume()
    }

    // MARK: - Private Type Methods

    /// Writes `updatedModel` to the temporary location, moves it to the permanent
    /// location, and reloads the live model from there.
    private static func saveUpdatedModel(_ updatedModel: MLModel & MLWritable) {
        let fileManager = FileManager.default
        do {
            // Create a directory for the updated model.
            try fileManager.createDirectory(at: tempUpdatedModelURL,
                                            withIntermediateDirectories: true,
                                            attributes: nil)

            // Save the updated model to temporary filename.
            try updatedModel.write(to: tempUpdatedModelURL)

            // Replace any previously updated model with this one.
            _ = try fileManager.replaceItemAt(updatedModelURL,
                                              withItemAt: tempUpdatedModelURL)

            loadUpdatedModel()
            print("Updated model saved to:\n\t\(updatedModelURL)")
        } catch let error {
            print("Could not save updated model to the file system: \(error)")
        }
    }

    /// Loads the updated Spam Classifier, if available.
    ///
    /// Leaves the current live model untouched when no updated model exists on
    /// disk or when it cannot be loaded.
    /// - Tag: LoadUpdatedModel
    private static func loadUpdatedModel() {
        guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
            // The updated model is not present at its designated path.
            return
        }

        do {
            // Use this updated model to make predictions in the future.
            updatedSpamClassifier = try SpamClassifier(contentsOf: updatedModelURL)
        } catch {
            // Report the failure instead of silently ignoring it.
            print("Could not load updated model at \(updatedModelURL): \(error)")
        }
    }
}