網站首頁 編程語言 正文
隨著電子商務和在線網站的出現,圖像檢索在我們的日常生活中的應用一直在增加。
亞馬遜、阿里巴巴、Myntra等公司一直在大量利用圖像檢索技術。當然,只有當通常的信息檢索技術失敗時,圖像檢索才會開始工作。
背景
圖像檢索的基本本質是根據查詢圖像的特征從集合或數據庫中查找圖像。
大多數情況下,這種特征是圖像之間簡單的視覺相似性。在一個復雜的問題中,這種特征可能是兩幅圖像在風格上的相似性,甚至是互補性。
由于原始形式的圖像不會在基于像素的數據中反映這些特征,因此我們需要將這些像素數據轉換為一個潛空間,在該空間中,圖像的表示將反映這些特征。
一般來說,在潛空間中,任何兩個相似的圖像都會相互靠近,而不同的圖像則會相隔很遠。這是我們用來訓練我們的模型的基本管理規則。一旦我們這樣做,檢索部分只需搜索潛在空間,在給定查詢圖像表示的潛在空間中拾取最近的圖像。大多數情況下,它是在最近鄰搜索的幫助下完成的。
因此,我們可以將我們的方法分為兩部分:
- 圖像表現
- 搜索
我們將在Oxford 102 Flowers數據集上解決這兩個部分。
圖像表現
我們將使用一種叫做暹羅模型的東西,它本身并不是一種全新的模型,而是一種訓練模型的技術。大多數情況下,這是與triplet loss一起使用的。這個技術的基本組成部分是三元組。
三元組是3個獨立的數據樣本,比如A(錨點),B(陽性)和C(陰性);其中A和B相似或具有相似的特征(可能是同一類),而C與A和B都不相似。這三個樣本共同構成了訓練數據的一個單元——三元組。
注:任何圖像檢索任務的90%都體現在暹羅網絡、triplet loss和三元組的創建中。如果你成功地完成了這些,那么整個努力的成功或多或少是有保證的。
首先,我們將創建管道的這個組件——數據。下面我們將在PyTorch中創建一個自定義數據集和數據加載器,它將從數據集中生成三元組。
class TripletData(Dataset): def __init__(self, path, transforms, split="train"): self.path = path self.split = split # train or valid self.cats = 102 # number of categories self.transforms = transforms def __getitem__(self, idx): # our positive class for the triplet idx = str(idx%self.cats + 1) # choosing our pair of positive images (im1, im2) positives = os.listdir(os.path.join(self.path, idx)) im1, im2 = random.sample(positives, 2) # choosing a negative class and negative image (im3) negative_cats = [str(x+1) for x in range(self.cats)] negative_cats.remove(idx) negative_cat = str(random.choice(negative_cats)) negatives = os.listdir(os.path.join(self.path, negative_cat)) im3 = random.choice(negatives) im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3) im1 = self.transforms(Image.open(im1)) im2 = self.transforms(Image.open(im2)) im3 = self.transforms(Image.open(im3)) return [im1, im2, im3] # we'll put some value that we want since there can be far too many triplets possible # multiples of the number of images/ number of categories is a good choice def __len__(self): return self.cats*8 # Transforms train_transforms = transforms.Compose([ transforms.Resize((224,224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) val_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) # Datasets and Dataloaders train_data = TripletData(PATH_TRAIN, train_transforms) val_data = TripletData(PATH_VALID, val_transforms) train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4) val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)
現在我們有了數據,讓我們轉到暹羅網絡。
暹羅網絡給人的印象是2個或3個模型,但是它本身是一個單一的模型。所有這些模型共享權重,即只有一個模型。
如前所述,將整個體系結構結合在一起的關鍵因素是triplet loss。triplet loss產生了一個目標函數,該函數迫使相似輸入對(錨點和正)之間的距離小于不同輸入對(錨點和負)之間的距離,并限定一定的閾值。
下面我們來看看triplet loss以及訓練管道實現。
class TripletLoss(nn.Module): def __init__(self, margin=1.0): super(TripletLoss, self).__init__() self.margin = margin def calc_euclidean(self, x1, x2): return (x1 - x2).pow(2).sum(1) # Distances in embedding space is calculated in euclidean def forward(self, anchor, positive, negative): distance_positive = self.calc_euclidean(anchor, positive) distance_negative = self.calc_euclidean(anchor, negative) losses = torch.relu(distance_positive - distance_negative + self.margin) return losses.mean() device = 'cuda' # Our base model model = models.resnet18().cuda() optimizer = optim.Adam(model.parameters(), lr=0.001) triplet_loss = TripletLoss() # Training for epoch in range(epochs): model.train() epoch_loss = 0.0 for data in tqdm(train_loader): optimizer.zero_grad() x1,x2,x3 = data e1 = model(x1.to(device)) e2 = model(x2.to(device)) e3 = model(x3.to(device)) loss = triplet_loss(e1,e2,e3) epoch_loss += loss loss.backward() optimizer.step() print("Train Loss: {}".format(epoch_loss.item())) class TripletLoss(nn.Module): def __init__(self, margin=1.0): super(TripletLoss, self).__init__() self.margin = margin def calc_euclidean(self, x1, x2): return (x1 - x2).pow(2).sum(1) # Distances in embedding space is calculated in euclidean def forward(self, anchor, positive, negative): distance_positive = self.calc_euclidean(anchor, positive) distance_negative = self.calc_euclidean(anchor, negative) losses = torch.relu(distance_positive - distance_negative + self.margin) return losses.mean() device = 'cuda' # Our base model model = models.resnet18().cuda() optimizer = optim.Adam(model.parameters(), lr=0.001) triplet_loss = TripletLoss() # Training for epoch in range(epochs): model.train() epoch_loss = 0.0 for data in tqdm(train_loader): optimizer.zero_grad() x1,x2,x3 = data e1 = model(x1.to(device)) e2 = model(x2.to(device)) e3 = model(x3.to(device)) loss = triplet_loss(e1,e2,e3) epoch_loss += loss loss.backward() optimizer.step() print("Train Loss: {}".format(epoch_loss.item()))
到目前為止,我們的模型已經經過訓練,可以將圖像轉換為一個嵌入空間。接下來,我們進入搜索部分。
搜索
我們可以很容易地使用Scikit Learn提供的最近鄰搜索。我們將探索新的更好的東西,而不是走簡單的路線。
我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的圖像,這種速度上的差異會變得更加明顯。
下面我們將演示如何在給定查詢圖像時,在存儲的圖像表示中搜索最近的圖像。
#!pip install faiss-gpu import faiss faiss_index = faiss.IndexFlatL2(1000) # build the index # storing the image representations im_indices = [] with torch.no_grad(): for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')): im = Image.open(f) im = im.resize((224,224)) im = torch.tensor([val_transforms(im).numpy()]).cuda() preds = model(im) preds = np.array([preds[0].cpu().numpy()]) faiss_index.add(preds) #add the representation to index im_indices.append(f) #store the image name to find it later on # Retrieval with a query image with torch.no_grad(): for f in os.listdir(PATH_TEST): # query/test image im = Image.open(os.path.join(PATH_TEST,f)) im = im.resize((224,224)) im = torch.tensor([val_transforms(im).numpy()]).cuda() test_embed = model(im).cpu().numpy() _, I = faiss_index.search(test_embed, 5) print("Retrieved Image: {}".format(im_indices[I[0][0]]))
這涵蓋了基于現代深度學習的圖像檢索,但不會使其變得太復雜。大多數檢索問題都可以通過這個基本管道解決。
原文鏈接:https://blog.csdn.net/woshicver/article/details/124030454
相關推薦
- 2022-09-06 python?numpy中array與pandas的DataFrame轉換方式_python
- 2022-06-22 Python?Tkinter?GUI編程實現Frame切換_python
- 2022-06-19 詳解Rainbond內置ServiceMesh微服務架構_云其它
- 2024-03-04 css鼠標移動上去才顯示滾動條
- 2022-05-06 golang導入私有倉庫報錯:“server response: not found:xxx: in
- 2022-03-31 C#算法之羅馬數字轉整數_C#教程
- 2022-05-27 C++帶頭雙向循環鏈表超詳細解析_C 語言
- 2022-06-18 C#如何實現dataGridView動態綁定數據_C#教程
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支