# functional.py
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7import numpy as np
8import torch
9from PIL import Image, ImageOps, ImageEnhance
10
11
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a CHW tensor.

    Unlike torchvision's ``to_tensor``, byte images are promoted to float
    WITHOUT rescaling to [0, 1] — values stay in the 0-255 range.

    Args:
        pic (PIL Image or numpy.ndarray): image to convert; arrays are
            expected in HW or HWC layout.

    Returns:
        Tensor: image in CHW layout; float tensor when the input was bytes,
        otherwise the dtype matching the source.
    """
    if isinstance(pic, np.ndarray):
        assert pic.ndim in (2, 3)
        # Promote grayscale HxW to HxWx1 so the HWC->CHW transpose works.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        tensor = torch.from_numpy(pic.transpose((2, 0, 1)))
        # Byte inputs are converted to float for backward compatibility.
        return tensor.float() if isinstance(tensor, torch.ByteTensor) else tensor

    # PIL path: choose a dtype matching the image mode.
    mode = pic.mode
    if mode == 'I':
        tensor = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif mode == 'I;16':
        tensor = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif mode == 'F':
        tensor = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif mode == '1':
        # 1-bit images: scale {0, 1} up to {0, 255}.
        tensor = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if mode == 'YCbCr':
        nchannel = 3
    elif mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(mode)
    tensor = tensor.view(pic.size[1], pic.size[0], nchannel)
    # HWC -> CHW; this transpose dominates the loading time/CPU.
    tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
    return tensor.float() if isinstance(tensor, torch.ByteTensor) else tensor
62
63
def int_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and truncate to an int.

    Args:
        level: level of the operation, expected in [0, PARAMETER_MAX] (= 10).
        maxval: maximum value the operation can have; scaled by level/10.

    Returns:
        An int resulting from scaling ``maxval`` according to ``level``.
    """
    scaled = level * maxval / 10
    return int(scaled)
74
75
def float_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` as a float.

    Args:
        level: level of the operation, expected in [0, PARAMETER_MAX] (= 10).
        maxval: maximum value the operation can have; scaled by level/10.

    Returns:
        A float resulting from scaling ``maxval`` according to ``level``.
    """
    return maxval * float(level) / 10.
86
87
def sample_level(n):
    """Draw a random operation level uniformly from [0.1, n).

    The ``def`` line was lost in extraction (source numbering jumps 87->89);
    restored here — every augmentation op below calls ``sample_level(level)``.

    Args:
        n: upper bound (exclusive) for the sampled level.

    Returns:
        A float sampled uniformly from [0.1, n).
    """
    return np.random.uniform(low=0.1, high=n)
90
91
def autocontrast(pil_img, *args):
    """Maximize image contrast via ``ImageOps.autocontrast``.

    Extra positional args are accepted and ignored so every augmentation
    op in this module shares a common call signature.
    """
    return ImageOps.autocontrast(pil_img)
94
95
def equalize(pil_img, *args):
    """Equalize the image histogram via ``ImageOps.equalize``.

    Extra positional args are accepted and ignored so every augmentation
    op in this module shares a common call signature.
    """
    return ImageOps.equalize(pil_img)
98
99
def posterize(pil_img, level, *args):
    """Reduce bits per channel; a higher level removes more bits (keeps 4-0)."""
    bits_removed = int_parameter(sample_level(level), 4)
    return ImageOps.posterize(pil_img, 4 - bits_removed)
103
104
def rotate(pil_img, level, *args):
    """Rotate by a sampled magnitude up to 30 degrees, direction chosen at random."""
    magnitude = int_parameter(sample_level(level), 30)
    flip = np.random.uniform() > 0.5
    degrees = -magnitude if flip else magnitude
    return pil_img.rotate(degrees, resample=Image.BILINEAR)
110
111
def solarize(pil_img, level, *args):
    """Invert all pixels above a threshold; a higher level lowers the threshold."""
    threshold_drop = int_parameter(sample_level(level), 256)
    return ImageOps.solarize(pil_img, 256 - threshold_drop)
115
116
def shear_x(pil_img, level, image_size):
    """Shear horizontally by a sampled factor up to 0.3, sign chosen at random."""
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Affine coefficients (a, b, c, d, e, f): x' = a*x + b*y + c, y' = d*x + e*y + f.
    coeffs = (1, factor, 0, 0, 1, 0)
    return pil_img.transform(image_size, Image.AFFINE, coeffs,
                             resample=Image.BILINEAR)
124
125
def shear_y(pil_img, level, image_size):
    """Shear vertically by a sampled factor up to 0.3, sign chosen at random."""
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Affine coefficients (a, b, c, d, e, f): x' = a*x + b*y + c, y' = d*x + e*y + f.
    coeffs = (1, 0, 0, factor, 1, 0)
    return pil_img.transform(image_size, Image.AFFINE, coeffs,
                             resample=Image.BILINEAR)
133
134
def translate_x(pil_img, level, image_size):
    """Translate horizontally by up to a third of the width, sign chosen at random."""
    offset = int_parameter(sample_level(level), image_size[0] / 3)
    if np.random.random() > 0.5:
        offset = -offset
    coeffs = (1, 0, offset, 0, 1, 0)
    return pil_img.transform(image_size, Image.AFFINE, coeffs,
                             resample=Image.BILINEAR)
142
143
def translate_y(pil_img, level, image_size):
    """Translate vertically by up to a third of the height, sign chosen at random."""
    offset = int_parameter(sample_level(level), image_size[1] / 3)
    if np.random.random() > 0.5:
        offset = -offset
    coeffs = (1, 0, 0, 0, 1, offset)
    return pil_img.transform(image_size, Image.AFFINE, coeffs,
                             resample=Image.BILINEAR)
151
152
# operation that overlaps with ImageNet-C's test set
def color(pil_img, level, *args):
    """Adjust color saturation by an enhancement factor in [0.1, 1.9]."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Color(pil_img).enhance(factor)
157
158
# operation that overlaps with ImageNet-C's test set
def contrast(pil_img, level, *args):
    """Adjust contrast by an enhancement factor in [0.1, 1.9]."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Contrast(pil_img).enhance(factor)
163
164
# operation that overlaps with ImageNet-C's test set
def brightness(pil_img, level, *args):
    """Adjust brightness by an enhancement factor in [0.1, 1.9]."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Brightness(pil_img).enhance(factor)
169
170
# operation that overlaps with ImageNet-C's test set
def sharpness(pil_img, level, *args):
    """Adjust sharpness by an enhancement factor in [0.1, 1.9]."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    return ImageEnhance.Sharpness(pil_img).enhance(factor)
175
176
# Op set for ReID: relative to augmentations_all, omits rotate, solarize,
# translate_x and translate_y. NOTE(review): presumably because those can
# move/destroy identity cues — confirm with the training config.
augmentations_reid = [
    autocontrast, equalize, posterize, shear_x, shear_y,
    color, contrast, brightness, sharpness
]

# Default op set: excludes the four ops marked above as overlapping with
# ImageNet-C's test set (color, contrast, brightness, sharpness).
augmentations = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y
]

# Full op set, including the ImageNet-C-overlapping enhancement ops.
augmentations_all = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y, color, contrast, brightness, sharpness
]
# (removed: stray cross-reference call fragments left over from the
# documentation export; they were not valid module-level statements)