Post

Replies

Boosts

Views

Activity

On-device training of a text classifier model
I have made a text classifier model, and I want to train it on device too: when text is classified wrongly, the user should be able to update the model on device. Code:

//
//  SpamClassifierHelper.swift
//  LearningML
//
//  Created by Himan Dhawan on 7/1/24.
//

import Foundation
import CreateMLComponents
import CoreML
import NaturalLanguage

/// The class labels the spam classifier predicts and is trained with.
///
/// The raw values are the label strings fed to the model ("spam" / "ham").
enum TextClassifier: String {
    case spam = "spam"
    case notASpam = "ham"
}

/// Errors thrown by `SpamClassifierModel.updateModel(newEntryText:spam:)`.
enum SpamClassifierUpdateError: Error {
    /// `SpamClassifier.mlmodelc` was not found in the main bundle.
    case modelNotFound
    /// The model at the given URL is not marked updatable, so `MLUpdateTask` cannot train it.
    /// Models must be exported as updatable for on-device training to work.
    case modelNotUpdatable(URL)
}

/// Owns the spam classifier used for predictions and personalizes it on device.
///
/// Predictions use the personalized model saved in Application Support when one
/// exists, and the model bundled with the app otherwise.
class SpamClassifierModel {

    // MARK: - Private Type Properties

    /// The updated (personalized) Spam Classifier model, once loaded from disk.
    /// NOTE(review): this is written from the MLUpdateTask completion handler, which
    /// runs off the main thread, while predictions may read it concurrently — confirm
    /// the app's threading or funnel access through one queue/actor.
    private static var updatedSpamClassifier: SpamClassifier?

    /// Whether `updatedModelURL` has been checked for a previously saved model yet.
    private static var didLookForSavedUpdate = false

    /// The default Spam Classifier model bundled with the app.
    ///
    /// A `static let` is initialized once on first use, so the bundled model is
    /// loaded a single time instead of on every access.
    private static let defaultSpamClassifier: SpamClassifier = {
        do {
            return try SpamClassifier(configuration: .init())
        } catch {
            fatalError("Couldn't load SpamClassifier due to: \(error.localizedDescription)")
        }
    }()

    /// The Spam Classifier model currently in use.
    ///
    /// On first access this also picks up a personalized model saved by an earlier launch.
    static var liveModel: SpamClassifier {
        if !didLookForSavedUpdate {
            loadUpdatedModel()
        }
        return updatedSpamClassifier ?? defaultSpamClassifier
    }

    /// The location of the app's Application Support directory for the user.
    private static let appDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!

    /// URL of the compiled model inside the bundle that contains this class.
    class var urlOfModelInThisBundle: URL {
        let bundle = Bundle(for: self)
        return bundle.url(forResource: "SpamClassifier", withExtension: "mlmodelc")!
    }

    /// The default Spam Classifier model's file URL.
    private static let defaultModelURL = urlOfModelInThisBundle

    /// The permanent location of the updated Spam Classifier model.
    private static let updatedModelURL = appDirectory.appendingPathComponent("personalized.mlmodelc")

    /// The temporary location of the updated Spam Classifier model.
    private static let tempUpdatedModelURL = appDirectory.appendingPathComponent("personalized_tmp.mlmodelc")

    // MARK: - Public Type Methods

    /// Classifies `value` with the live model.
    ///
    /// - Parameter value: The text to classify.
    /// - Returns: The predicted label (`nil` if the model produced none) and the
    ///   top hypothesis' probability as a percentage string with two decimals.
    ///   (`predication` is a misspelling kept for source compatibility.)
    /// - Throws: Any error from wrapping the Core ML model in an `NLModel`.
    static func predictLabelFor(_ value: String) throws -> (predication: String?, confidence: String) {
        let classifier = try NLModel(mlModel: liveModel.model)
        let label = classifier.predictedLabel(for: value)
        let confidence = classifier.predictedLabelHypotheses(for: value, maximumCount: 1).first?.value ?? 0
        return (label, String(format: "%.2f", confidence * 100))
    }

    /// Starts an on-device update of the classifier with one corrected example.
    ///
    /// Training continues from the latest personalized model when one is saved, so
    /// successive corrections accumulate instead of restarting from the bundled model.
    /// The update runs asynchronously: this method returns once the task is started,
    /// and the new model is saved and loaded from the task's completion handler.
    ///
    /// - Parameters:
    ///   - newEntryText: The text the user corrected.
    ///   - spam: The label the text should have.
    /// - Throws: `SpamClassifierUpdateError.modelNotFound` when the bundled model is
    ///   missing, `SpamClassifierUpdateError.modelNotUpdatable` when the model cannot be
    ///   trained on device, or any error from loading the model or creating the task.
    static func updateModel(newEntryText: String, spam: TextClassifier) throws {
        let baseModelURL: URL
        if FileManager.default.fileExists(atPath: updatedModelURL.path) {
            baseModelURL = updatedModelURL
        } else {
            guard let bundledURL = Bundle.main.url(forResource: "SpamClassifier", withExtension: "mlmodelc") else {
                throw SpamClassifierUpdateError.modelNotFound
            }
            baseModelURL = bundledURL
        }

        // Fail with a clear error instead of an opaque MLUpdateTask failure.
        let baseModel = try MLModel(contentsOf: baseModelURL)
        guard baseModel.modelDescription.isUpdatable else {
            throw SpamClassifierUpdateError.modelNotUpdatable(baseModelURL)
        }

        // Create a feature provider for the new text sample: the text is the input,
        // the class label is the training target.
        // NOTE(review): "text"/"label" must match the model's training input names
        // (see modelDescription.trainingInputDescriptionsByName).
        let sample = try MLDictionaryFeatureProvider(dictionary: [
            "text": MLFeatureValue(string: newEntryText),
            "label": MLFeatureValue(string: spam.rawValue),
        ])
        let trainingData = MLArrayBatchProvider(array: [sample])

        let updateTask = try MLUpdateTask(forModelAt: baseModelURL,
                                          trainingData: trainingData,
                                          configuration: nil,
                                          completionHandler: { context in
            // A failed task still delivers a model; don't persist it.
            if let error = context.task.error {
                print("Model update failed: \(error)")
                return
            }
            saveUpdatedModel(context.model)
        })
        updateTask.resume()
    }

    // MARK: - Private Type Methods

    /// Writes `updatedModel` to `updatedModelURL` and switches predictions to it.
    ///
    /// Failures are logged and leave the previously live model in place.
    private static func saveUpdatedModel(_ updatedModel: MLModel & MLWritable) {
        let fileManager = FileManager.default
        do {
            // Discard a leftover temporary copy from an earlier, interrupted save.
            if fileManager.fileExists(atPath: tempUpdatedModelURL.path) {
                try fileManager.removeItem(at: tempUpdatedModelURL)
            }
            // Create a directory for the updated model.
            try fileManager.createDirectory(at: tempUpdatedModelURL, withIntermediateDirectories: true, attributes: nil)

            // Save the updated model to temporary filename.
            try updatedModel.write(to: tempUpdatedModelURL)

            // Replace any previously updated model with this one.
            _ = try fileManager.replaceItemAt(updatedModelURL, withItemAt: tempUpdatedModelURL)

            loadUpdatedModel()
            print("Updated model saved to:\n\t\(updatedModelURL)")
        } catch {
            print("Could not save updated model to the file system: \(error)")
        }
    }

    /// Loads the updated Spam Classifier, if available.
    ///
    /// Leaves `updatedSpamClassifier` unchanged when no model is saved at
    /// `updatedModelURL` or it fails to load.
    /// - Tag: LoadUpdatedModel
    private static func loadUpdatedModel() {
        didLookForSavedUpdate = true

        guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
            // The updated model is not present at its designated path.
            return
        }

        // Create an instance of the updated model.
        guard let model = try? SpamClassifier(contentsOf: updatedModelURL) else {
            return
        }

        // Use this updated model to make predictions in the future.
        updatedSpamClassifier = model
    }
}
1
0
507
Jul ’24