|
@@ -0,0 +1,356 @@
|
|
|
+"use strict";
|
|
|
+
|
|
|
+// 常量定义
|
|
|
+const NUM_CLASSES = 8;
|
|
|
+const IMAGE_SIZE = 227; // MobileNet V1 通常需要 227x227 或 224x224,请根据实际模型调整
|
|
|
+const TOPK = 3;
|
|
|
+
|
|
|
+class Main {
|
|
|
+ // --- 只有一个构造函数 ---
|
|
|
+ constructor() {
|
|
|
+ this.infoTexts = [];
|
|
|
+ this.trainingClass = -1; // 用于标记哪个按钮被点击以进行训练
|
|
|
+ this.currentImageElement = null; // 当前用于预览和处理的图片元素
|
|
|
+ this.isAppReady = false;
|
|
|
+ this.trainingButtons = []; // *** 在这里初始化 trainingButtons ***
|
|
|
+
|
|
|
+ // 获取 DOM 元素
|
|
|
+ this.fileInput = document.getElementById("file-input");
|
|
|
+ this.imagePreview = document.getElementById("image-preview");
|
|
|
+ this.imagePlaceholder = document.getElementById("preview-placeholder");
|
|
|
+ this.classifyButton = document.getElementById("classify-button");
|
|
|
+ this.controlsContainer = document.getElementById("controls-container");
|
|
|
+ this.statusDiv = document.getElementById("status");
|
|
|
+
|
|
|
+ // 使用 CDN 加载的全局变量
|
|
|
+ this.tf = tf; // 来自 @tensorflow/tfjs
|
|
|
+ this.knnClassifier = knnClassifier; // 来自 @tensorflow-models/knn-classifier
|
|
|
+ this.mobilenet = mobilenet; // 来自 @tensorflow-models/mobilenet
|
|
|
+
|
|
|
+ // 初始化页面和模型
|
|
|
+ this.bindPage();
|
|
|
+
|
|
|
+ // 创建训练按钮和信息文本 (现在可以安全调用了)
|
|
|
+ this.createTrainingControls();
|
|
|
+
|
|
|
+ // 绑定事件监听器
|
|
|
+ this.bindEventListeners();
|
|
|
+ }
|
|
|
+ // --- 构造函数结束 ---
|
|
|
+
|
|
|
+ updateStatus(message) {
|
|
|
+ this.statusDiv.innerText = message;
|
|
|
+ console.log(message); // 同时在控制台输出
|
|
|
+ }
|
|
|
+
|
|
|
+ createTrainingControls() {
|
|
|
+ for (let i = 0; i < NUM_CLASSES; i++) {
|
|
|
+ const div = document.createElement("div");
|
|
|
+ div.classList.add("class-controls"); // 添加 CSS 类
|
|
|
+ this.controlsContainer.appendChild(div);
|
|
|
+
|
|
|
+ // 创建训练按钮
|
|
|
+ const button = document.createElement("button");
|
|
|
+ button.innerText = `训练物品 ${i}`;
|
|
|
+ button.disabled = true; // 初始禁用,直到模型加载完成
|
|
|
+ div.appendChild(button);
|
|
|
+
|
|
|
+ // 点击按钮时,使用当前预览的图片进行训练
|
|
|
+ button.addEventListener("click", async () => {
|
|
|
+ if (
|
|
|
+ this.imagePreview.style.display === "none" ||
|
|
|
+ !this.imagePreview.src ||
|
|
|
+ this.imagePreview.src.endsWith("#")
|
|
|
+ ) {
|
|
|
+ alert("请先选择一张图片再进行训练!");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ this.updateStatus(`正在为物品 ${i} 添加样本...`);
|
|
|
+ await this.addImageForTraining(i);
|
|
|
+ this.updateStatus(`物品 ${i} 的样本已添加。`);
|
|
|
+ });
|
|
|
+
|
|
|
+ // 创建信息文本
|
|
|
+ const infoText = document.createElement("span");
|
|
|
+ infoText.innerText = " 0 个学习样例";
|
|
|
+ div.appendChild(infoText);
|
|
|
+ this.infoTexts.push(infoText);
|
|
|
+ this.trainingButtons.push(button); // 保存按钮引用以便启用/禁用
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ bindEventListeners() {
|
|
|
+ // 文件选择事件
|
|
|
+ this.fileInput.addEventListener("change", (event) => {
|
|
|
+ const file = event.target.files[0];
|
|
|
+ if (file && file.type.startsWith("image/")) {
|
|
|
+ const reader = new FileReader();
|
|
|
+ reader.onload = (e) => {
|
|
|
+ this.imagePreview.src = e.target.result;
|
|
|
+ this.imagePreview.style.display = "block";
|
|
|
+ this.imagePlaceholder.style.display = "none"; // 隐藏占位符
|
|
|
+ this.classifyButton.disabled = !this.isAppReady; // 启用识别按钮(如果应用已就绪)
|
|
|
+ // 启用训练按钮
|
|
|
+ this.trainingButtons.forEach(
|
|
|
+ (btn) => (btn.disabled = !this.isAppReady)
|
|
|
+ );
|
|
|
+
|
|
|
+ this.updateStatus("图片已加载,可以进行识别或训练。");
|
|
|
+ };
|
|
|
+ reader.readAsDataURL(file);
|
|
|
+ } else {
|
|
|
+ // 清除预览并禁用按钮
|
|
|
+ this.imagePreview.src = "#";
|
|
|
+ this.imagePreview.style.display = "none";
|
|
|
+ this.imagePlaceholder.style.display = "block"; // 显示占位符
|
|
|
+ this.classifyButton.disabled = true;
|
|
|
+ this.trainingButtons.forEach((btn) => (btn.disabled = true));
|
|
|
+
|
|
|
+ if (file) {
|
|
|
+ alert("请选择一个图片文件!");
|
|
|
+ this.updateStatus("请选择一个图片文件。");
|
|
|
+ } else {
|
|
|
+ this.updateStatus("未选择图片。");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ // 识别按钮点击事件
|
|
|
+ this.classifyButton.addEventListener("click", async () => {
|
|
|
+ if (
|
|
|
+ this.imagePreview.style.display === "none" ||
|
|
|
+ !this.imagePreview.src ||
|
|
|
+ this.imagePreview.src.endsWith("#")
|
|
|
+ ) {
|
|
|
+ alert("请先选择一张图片再进行识别!");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ await this.classifyCurrentImage();
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ async bindPage() {
|
|
|
+ try {
|
|
|
+ this.updateStatus("正在初始化应用...");
|
|
|
+ // 创建 KNN 分类器实例
|
|
|
+ this.knn = this.knnClassifier.create();
|
|
|
+ this.updateStatus("KNN 分类器已创建");
|
|
|
+
|
|
|
+ // 加载 MobileNet 模型
|
|
|
+ this.updateStatus("正在加载 MobileNet 模型...");
|
|
|
+ try {
|
|
|
+ this.mobilenetModel = await this.mobilenet.load({
|
|
|
+ version: 1,
|
|
|
+ alpha: 1.0,
|
|
|
+ });
|
|
|
+ console.log("MobileNet V1 模型加载完成");
|
|
|
+ } catch (e) {
|
|
|
+ console.warn("MobileNet V1 加载失败,尝试加载 V2:", e);
|
|
|
+ this.mobilenetModel = await this.mobilenet.load({
|
|
|
+ version: 2,
|
|
|
+ alpha: 1.0,
|
|
|
+ });
|
|
|
+ console.log("MobileNet V2 模型加载完成");
|
|
|
+ }
|
|
|
+
|
|
|
+ this.updateStatus("模型加载完成,应用准备就绪。");
|
|
|
+ this.isAppReady = true;
|
|
|
+ // 如果用户在模型加载前就选了图片,现在启用按钮
|
|
|
+ if (this.imagePreview.style.display !== "none") {
|
|
|
+ this.classifyButton.disabled = false;
|
|
|
+ this.trainingButtons.forEach((btn) => (btn.disabled = false));
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ console.error("模型加载或初始化失败:", error);
|
|
|
+ this.updateStatus("错误:模型加载失败。请检查网络或控制台。");
|
|
|
+ alert("模型加载失败,请检查网络连接或控制台错误。");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 提取特征的辅助函数
|
|
|
+ async getFeaturesFromImage(imageElement) {
|
|
|
+ // *** 使用 tf.tidy 来管理此函数内部的张量 ***
|
|
|
+ return tf.tidy(() => {
|
|
|
+ let imageTensor = this.tf.browser.fromPixels(imageElement);
|
|
|
+
|
|
|
+ // 调整大小并标准化
|
|
|
+ let expectedSize = IMAGE_SIZE;
|
|
|
+ try {
|
|
|
+ if (
|
|
|
+ this.mobilenetModel &&
|
|
|
+ this.mobilenetModel.inputs &&
|
|
|
+ this.mobilenetModel.inputs[0].shape
|
|
|
+ ) {
|
|
|
+ const shape = this.mobilenetModel.inputs[0].shape;
|
|
|
+ if (
|
|
|
+ shape &&
|
|
|
+ shape.length >= 3 &&
|
|
|
+ shape[1] != null &&
|
|
|
+ shape[2] != null &&
|
|
|
+ shape[1] === shape[2]
|
|
|
+ ) {
|
|
|
+ expectedSize = shape[1];
|
|
|
+ console.log(`检测到模型输入尺寸: ${expectedSize}x${expectedSize}`);
|
|
|
+ } else {
|
|
|
+ console.warn(
|
|
|
+ `无法从模型输入形状 ${JSON.stringify(
|
|
|
+ shape
|
|
|
+ )} 推断尺寸,使用默认值 ${IMAGE_SIZE}。`
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (e) {
|
|
|
+ console.warn(
|
|
|
+ `获取模型输入尺寸时出错,使用默认值 ${IMAGE_SIZE}。错误: ${e}`
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ let resizedTensor = tf.image
|
|
|
+ .resizeBilinear(imageTensor, [expectedSize, expectedSize], true)
|
|
|
+ .expandDims(0);
|
|
|
+ let normalizedTensor = resizedTensor.toFloat().div(127.5).sub(1.0);
|
|
|
+
|
|
|
+ // 特征提取
|
|
|
+ let logits = null;
|
|
|
+ try {
|
|
|
+ if (typeof this.mobilenetModel.infer === "function") {
|
|
|
+ logits = this.mobilenetModel.infer(normalizedTensor, true);
|
|
|
+ } else if (typeof this.mobilenetModel.predict === "function") {
|
|
|
+ logits = this.mobilenetModel.predict(normalizedTensor);
|
|
|
+ if (Array.isArray(logits)) {
|
|
|
+ logits = logits[0];
|
|
|
+ }
|
|
|
+ if (logits.rank > 2) {
|
|
|
+ const batchSize = logits.shape[0];
|
|
|
+ logits = logits.reshape([batchSize, -1]);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ throw new Error("无法识别的模型接口");
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ console.error("MobileNet 推理失败: ", error);
|
|
|
+ return null; // 返回 null 表示失败
|
|
|
+ }
|
|
|
+ // *** tidy 会自动清理 imageTensor, resizedTensor, normalizedTensor ***
|
|
|
+ // *** 但 logits 因为被返回,所以不会被这个 tidy 清理 ***
|
|
|
+ return logits;
|
|
|
+ }); // *** tidy 结束 ***
|
|
|
+ }
|
|
|
+
|
|
|
+ // 添加训练样本 (修改后)
|
|
|
+ async addImageForTraining(classIndex) {
|
|
|
+ if (!this.isAppReady) {
|
|
|
+ alert("模型尚未准备好,请稍候。");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (
|
|
|
+ !this.imagePreview ||
|
|
|
+ this.imagePreview.style.display === "none" ||
|
|
|
+ !this.imagePreview.src ||
|
|
|
+ this.imagePreview.src.endsWith("#")
|
|
|
+ ) {
|
|
|
+ alert("请先选择一张图片。");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 1. 获取特征 (内部使用 tidy)
|
|
|
+ const logits = await this.getFeaturesFromImage(this.imagePreview);
|
|
|
+
|
|
|
+ if (logits) {
|
|
|
+ // 2. 执行异步 KNN 操作 (不在 tidy 内部)
|
|
|
+ this.knn.addExample(logits, classIndex);
|
|
|
+ this.updateExampleCountUI(classIndex);
|
|
|
+
|
|
|
+ // 3. **手动释放不再需要的 logits** (因为它被 getFeaturesFromImage 的 tidy 返回了)
|
|
|
+ logits.dispose();
|
|
|
+ } else {
|
|
|
+ this.updateStatus(`错误:无法从图片提取特征用于训练物品 ${classIndex}`);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 分类当前图片 (修改后)
|
|
|
+ async classifyCurrentImage() {
|
|
|
+ if (!this.isAppReady) {
|
|
|
+ alert("模型尚未准备好,请稍候。");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (
|
|
|
+ !this.imagePreview ||
|
|
|
+ this.imagePreview.style.display === "none" ||
|
|
|
+ !this.imagePreview.src ||
|
|
|
+ this.imagePreview.src.endsWith("#")
|
|
|
+ ) {
|
|
|
+ alert("请先选择一张图片。");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ this.updateStatus("正在识别图片...");
|
|
|
+
|
|
|
+ const numClasses = this.knn.getNumClasses();
|
|
|
+ if (numClasses === 0) {
|
|
|
+ this.updateStatus("错误:您还没有训练任何物品类别!");
|
|
|
+ alert("请先训练至少一个物品类别再进行识别。");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 1. 获取特征 (内部使用 tidy)
|
|
|
+ const logits = await this.getFeaturesFromImage(this.imagePreview);
|
|
|
+
|
|
|
+ if (logits) {
|
|
|
+ // 2. 执行异步 KNN 操作 (不在 tidy 内部)
|
|
|
+ const res = await this.knn.predictClass(logits, TOPK);
|
|
|
+ this.updatePredictionUI(res);
|
|
|
+ this.updateStatus(
|
|
|
+ `识别完成。最可能的类别: 物品 ${res.classIndex} (${(
|
|
|
+ res.confidences[res.classIndex] * 100
|
|
|
+ ).toFixed(1)}%)`
|
|
|
+ );
|
|
|
+
|
|
|
+ // 3. **手动释放不再需要的 logits**
|
|
|
+ logits.dispose();
|
|
|
+ } else {
|
|
|
+ this.updateStatus("错误:无法从图片提取特征进行识别。");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 更新特定类别的样本计数UI (保持不变)
|
|
|
+ updateExampleCountUI(classIndex) {
|
|
|
+ const exampleCount = this.knn.getClassExampleCount();
|
|
|
+ if (this.infoTexts[classIndex]) {
|
|
|
+ const count = exampleCount[classIndex] || 0;
|
|
|
+ const currentText = this.infoTexts[classIndex].innerText;
|
|
|
+ const parts = currentText.split(" - ");
|
|
|
+ this.infoTexts[classIndex].innerText =
|
|
|
+ ` ${count} 个学习样例` + (parts.length > 1 ? ` - ${parts[1]}` : "");
|
|
|
+ this.infoTexts[classIndex].style.fontWeight = "normal";
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 更新所有类别的预测UI
|
|
|
+ updatePredictionUI(predictionResult) {
|
|
|
+ const exampleCount = this.knn.getClassExampleCount();
|
|
|
+ for (let i = 0; i < NUM_CLASSES; i++) {
|
|
|
+ const infoText = this.infoTexts[i];
|
|
|
+ if (infoText) {
|
|
|
+ // 确保元素存在
|
|
|
+ const count = exampleCount[i] || 0;
|
|
|
+ // 使用 ?. 确保 confidences 存在且属性存在
|
|
|
+ const confidence =
|
|
|
+ predictionResult?.confidences?.[i] != null
|
|
|
+ ? (predictionResult.confidences[i] * 100).toFixed(1)
|
|
|
+ : "0.0"; // 默认为 0%
|
|
|
+ infoText.innerText = ` ${count} 个学习样例 - ${confidence}%`;
|
|
|
+
|
|
|
+ // 加粗预测结果
|
|
|
+ if (predictionResult.classIndex === i) {
|
|
|
+ infoText.style.fontWeight = "bold";
|
|
|
+ } else {
|
|
|
+ infoText.style.fontWeight = "normal";
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// 页面加载完成后启动应用
|
|
|
+window.addEventListener("load", () => new Main());
|